Skip to content

Commit 5333d09

Browse files
authored
fix: scalar subquery pushdown and reuse for CometNativeScanExec (SPARK-43402) (#4053)
1 parent 724152e commit 5333d09

8 files changed

Lines changed: 333 additions & 41 deletions

File tree

dev/diffs/4.0.1.diff

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,24 +1255,23 @@ index 0df7f806272..92390bd819f 100644
12551255

12561256
test("non-matching optional group") {
12571257
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
1258-
index 2e33f6505ab..949fdea0003 100644
1258+
index 2e33f6505ab..54f5081e10a 100644
12591259
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
12601260
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
1261-
@@ -23,10 +23,12 @@ import org.apache.spark.SparkRuntimeException
1261+
@@ -23,10 +23,11 @@ import org.apache.spark.SparkRuntimeException
12621262
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
12631263
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
12641264
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project, Sort, Union}
12651265
+import org.apache.spark.sql.comet.{CometNativeColumnarToRowExec, CometNativeScanExec, CometScanExec}
12661266
import org.apache.spark.sql.execution._
1267-
+import org.apache.spark.sql.IgnoreCometNativeDataFusion
12681267
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
12691268
import org.apache.spark.sql.execution.datasources.FileScanRDD
12701269
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
12711270
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
12721271
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
12731272
import org.apache.spark.sql.internal.SQLConf
12741273
import org.apache.spark.sql.test.SharedSparkSession
1275-
@@ -1529,6 +1531,18 @@ class SubquerySuite extends QueryTest
1274+
@@ -1529,6 +1530,18 @@ class SubquerySuite extends QueryTest
12761275
fs.inputRDDs().forall(
12771276
_.asInstanceOf[FileScanRDD].filePartitions.forall(
12781277
_.files.forall(_.urlEncodedPath.contains("p=0"))))
@@ -1291,7 +1290,7 @@ index 2e33f6505ab..949fdea0003 100644
12911290
case _ => false
12921291
})
12931292
}
1294-
@@ -2094,7 +2108,7 @@ class SubquerySuite extends QueryTest
1293+
@@ -2094,7 +2107,7 @@ class SubquerySuite extends QueryTest
12951294

12961295
df.collect()
12971296
val exchanges = collect(df.queryExecution.executedPlan) {
@@ -1300,13 +1299,7 @@ index 2e33f6505ab..949fdea0003 100644
13001299
}
13011300
assert(exchanges.size === 1)
13021301
}
1303-
@@ -2674,22 +2688,31 @@ class SubquerySuite extends QueryTest
1304-
}
1305-
}
1306-
1307-
- test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery") {
1308-
+ test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery",
1309-
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3315")) {
1302+
@@ -2678,18 +2691,25 @@ class SubquerySuite extends QueryTest
13101303
def checkFileSourceScan(query: String, answer: Seq[Row]): Unit = {
13111304
val df = sql(query)
13121305
checkAnswer(df, answer)
@@ -1315,6 +1308,7 @@ index 2e33f6505ab..949fdea0003 100644
13151308
+ val dataSourceScanExec = collect(df.queryExecution.executedPlan) {
13161309
+ case f: FileSourceScanLike => f
13171310
+ case c: CometScanExec => c
1311+
+ case n: CometNativeScanExec => n
13181312
}
13191313
sparkContext.listenerBus.waitUntilEmpty()
13201314
- assert(fileSourceScanExec.size === 1)
@@ -1324,13 +1318,11 @@ index 2e33f6505ab..949fdea0003 100644
13241318
+ assert(dataSourceScanExec.size === 1)
13251319
+ val scalarSubquery = dataSourceScanExec.head match {
13261320
+ case f: FileSourceScanLike =>
1327-
+ f.dataFilters.flatMap(_.collect {
1328-
+ case s: ScalarSubquery => s
1329-
+ })
1321+
+ f.dataFilters.flatMap(_.collect { case s: ScalarSubquery => s })
13301322
+ case c: CometScanExec =>
1331-
+ c.dataFilters.flatMap(_.collect {
1332-
+ case s: ScalarSubquery => s
1333-
+ })
1323+
+ c.dataFilters.flatMap(_.collect { case s: ScalarSubquery => s })
1324+
+ case n: CometNativeScanExec =>
1325+
+ n.dataFilters.flatMap(_.collect { case s: ScalarSubquery => s })
13341326
+ }
13351327
assert(scalarSubquery.length === 1)
13361328
assert(scalarSubquery.head.plan.isInstanceOf[ReusedSubqueryExec])

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,45 @@ import org.apache.spark.sql.execution._
3232
import org.apache.spark.sql.internal.SQLConf
3333

3434
import org.apache.comet.CometConf._
35-
import org.apache.comet.rules.{CometExecRule, CometScanRule, EliminateRedundantTransitions}
35+
import org.apache.comet.rules.{CometExecRule, CometReuseSubquery, CometScanRule, EliminateRedundantTransitions}
3636
import org.apache.comet.shims.ShimCometSparkSessionExtensions
3737

3838
/**
3939
* CometDriverPlugin will register an instance of this class with Spark.
4040
*
41-
* This class is responsible for injecting Comet rules and extensions into Spark.
41+
* Comet rules are injected into Spark's rule pipeline at several extension points. The execution
42+
* order differs between AQE and non-AQE paths:
43+
*
44+
* Non-AQE (QueryExecution.preparations):
45+
* {{{
46+
* 1. PlanDynamicPruningFilters -- Spark creates DPP filters
47+
* 2. PlanSubqueries -- Spark creates SubqueryExec for scalar subqueries
48+
* 3. EnsureRequirements -- Spark inserts shuffles/sorts
49+
* 4. ApplyColumnarRulesAndInsertTransitions:
50+
* a. preColumnarTransitions: CometScanRule, CometExecRule (replace Spark -> Comet nodes)
51+
* b. insertTransitions: ColumnarToRow/RowToColumnar added
52+
* c. postColumnarTransitions: EliminateRedundantTransitions
53+
* 5. ReuseExchangeAndSubquery -- Spark deduplicates subqueries (sees Comet nodes)
54+
* }}}
55+
*
56+
* AQE (AdaptiveSparkPlanExec):
57+
* {{{
58+
* Initial plan:
59+
* queryStagePreparationRules: CometScanRule, CometExecRule (replace Spark -> Comet nodes)
60+
*
61+
* Per stage (optimizeQueryStage + postStageCreationRules):
62+
* 1. queryStageOptimizerRules: ReuseAdaptiveSubquery, CometReuseSubquery
63+
* 2. postStageCreationRules -> ApplyColumnarRulesAndInsertTransitions:
64+
* a. preColumnarTransitions: CometScanRule, CometExecRule (no-ops, already converted)
65+
* b. insertTransitions
66+
* c. postColumnarTransitions: EliminateRedundantTransitions
67+
* }}}
68+
*
69+
* CometReuseSubquery is needed in AQE because Spark's ReuseAdaptiveSubquery may run before
70+
* Comet's node replacements in the initial plan construction, and the replacements can disrupt
71+
* subquery reuse that was already applied. The shim-based registration
72+
* (injectQueryStageOptimizerRuleShim) handles API availability: Spark 3.5+ has
73+
* injectQueryStageOptimizerRule, Spark 3.4 does not (no-op).
4274
*/
4375
class CometSparkSessionExtensions
4476
extends (SparkSessionExtensions => Unit)
@@ -49,6 +81,7 @@ class CometSparkSessionExtensions
4981
extensions.injectColumnar { session => CometExecColumnar(session) }
5082
extensions.injectQueryStagePrepRule { session => CometScanRule(session) }
5183
extensions.injectQueryStagePrepRule { session => CometExecRule(session) }
84+
injectQueryStageOptimizerRuleShim(extensions, CometReuseSubquery)
5285
}
5386

5487
case class CometScanColumnar(session: SparkSession) extends ColumnarRule {
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.rules
21+
22+
import scala.collection.mutable
23+
24+
import org.apache.spark.sql.catalyst.rules.Rule
25+
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
26+
import org.apache.spark.sql.execution.{BaseSubqueryExec, ExecSubqueryExpression, ReusedSubqueryExec, SparkPlan}
27+
28+
/**
29+
* Re-applies subquery deduplication after Comet node conversions.
30+
*
31+
* Spark's ReuseAdaptiveSubquery runs as a queryStageOptimizerRule, while CometScanRule and
32+
* CometExecRule replace Spark operators with Comet equivalents as queryStagePrepRules. The
33+
* Comet rules copy expressions from the original Spark nodes, which can disrupt subquery reuse
34+
* that was already applied by Spark's rule. This rule runs after Comet conversions to restore
35+
* proper deduplication.
36+
*
37+
* Uses the same algorithm as Spark's ReuseExchangeAndSubquery (subquery portion): top-down
38+
* traversal via transformAllExpressionsWithPruning, caching by canonical form.
39+
*
40+
* For non-AQE, Spark's ReuseExchangeAndSubquery runs after ApplyColumnarRulesAndInsertTransitions
41+
* in QueryExecution.preparations and handles reuse correctly without this rule.
42+
*
43+
* @see
44+
* ReuseExchangeAndSubquery (Spark's non-AQE subquery reuse)
45+
* @see
46+
* ReuseAdaptiveSubquery (Spark's AQE subquery reuse)
47+
*/
48+
case object CometReuseSubquery extends Rule[SparkPlan] {
49+
50+
def apply(plan: SparkPlan): SparkPlan = {
51+
if (!conf.subqueryReuseEnabled) {
52+
return plan
53+
}
54+
55+
val cache = mutable.Map.empty[SparkPlan, BaseSubqueryExec]
56+
57+
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
58+
case sub: ExecSubqueryExpression if !sub.plan.isInstanceOf[ReusedSubqueryExec] =>
59+
val cached = cache.getOrElseUpdate(sub.plan.canonicalized, sub.plan)
60+
if (cached.ne(sub.plan)) {
61+
sub.withNewPlan(ReusedSubqueryExec(cached))
62+
} else {
63+
sub
64+
}
65+
}
66+
}
67+
}

spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
3030
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
3131
import org.apache.spark.sql.comet.shims.ShimStreamSourceAwareSparkPlan
3232
import org.apache.spark.sql.execution._
33+
import org.apache.spark.sql.execution.{ScalarSubquery => ExecScalarSubquery}
3334
import org.apache.spark.sql.execution.datasources._
3435
import org.apache.spark.sql.execution.metric.SQLMetric
3536
import org.apache.spark.sql.types._
@@ -41,6 +42,7 @@ import com.google.common.base.Objects
4142

4243
import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetUtils}
4344
import org.apache.comet.serde.OperatorOuterClass.Operator
45+
import org.apache.comet.serde.QueryPlanSerde.exprToProto
4446

4547
/**
4648
* Native scan operator for DataSource V1 Parquet files using DataFusion's ParquetExec.
@@ -77,23 +79,30 @@ case class CometNativeScanExec(
7779
override lazy val metadata: Map[String, String] = originalPlan.metadata
7880

7981
/**
80-
* Prepare DPP subquery plans before execution.
82+
* Prepare subquery plans before execution.
8183
*
82-
* For non-AQE DPP, partitionFilters contains DynamicPruningExpression(InSubqueryExec(...))
83-
* inserted by PlanDynamicPruningFilters (which runs before Comet rules). We call
84-
* e.plan.prepare() here so that the subquery plans are set up before execution begins.
84+
* DPP: partitionFilters may contain DynamicPruningExpression(InSubqueryExec(...)) from
85+
* PlanDynamicPruningFilters.
8586
*
86-
* Note: doPrepare() alone is NOT sufficient for DPP resolution. serializedPartitionData can be
87-
* triggered from findAllPlanData (via commonData) on a BroadcastExchangeExec thread, outside
88-
* the normal prepare() -> executeSubqueries() flow. The actual DPP resolution (updateResult)
89-
* happens in serializedPartitionData below.
87+
* Scalar subquery pushdown (SPARK-43402, Spark 4.0+): dataFilters may contain ScalarSubquery.
88+
*
89+
* serializedPartitionData can be triggered outside the normal prepare() -> executeSubqueries()
90+
* flow (e.g., from a BroadcastExchangeExec thread), so we prepare subquery plans here and
91+
* resolve them explicitly in serializedPartitionData via updateResult().
9092
*/
9193
override protected def doPrepare(): Unit = {
9294
partitionFilters.foreach {
9395
case DynamicPruningExpression(e: InSubqueryExec) =>
9496
e.plan.prepare()
9597
case _ =>
9698
}
99+
dataFilters.foreach { f =>
100+
f.foreach {
101+
case s: ExecScalarSubquery =>
102+
s.plan.prepare()
103+
case _ =>
104+
}
105+
}
97106
super.doPrepare()
98107
}
99108

@@ -138,7 +147,7 @@ case class CometNativeScanExec(
138147
//
139148
// originalPlan.inputRDD triggers FileSourceScanExec's full scan pipeline including
140149
// codegen on partition filter expressions. With DPP, this calls
141-
// InSubqueryExec.doGenCode which requires the subquery to have finished but
150+
// InSubqueryExec.doGenCode which requires the subquery to have finished - but
142151
// outputPartitioning can be accessed before prepare() runs (e.g., by
143152
// ValidateRequirements during plan validation).
144153
//
@@ -208,8 +217,40 @@ case class CometNativeScanExec(
208217
case _ =>
209218
}
210219
}
211-
// Extract common data from nativeOp
212-
val commonBytes = nativeOp.getNativeScan.getCommon.toByteArray
220+
// Resolve scalar subqueries in dataFilters and push to the native Parquet reader.
221+
// supportedDataFilters excludes PlanExpression at planning time (unresolved), so these
222+
// aren't in the serialized native plan yet. We resolve them here and append to the
223+
// NativeScanCommon protobuf. Same approach as FileSourceScanLike.pushedDownFilters
224+
// (DataSourceScanExec.scala), which resolves ScalarSubquery -> Literal at execution time.
225+
val commonBytes = {
226+
val base = nativeOp.getNativeScan.getCommon
227+
val scalarSubqueryFilters = dataFilters
228+
.filter(_.exists(_.isInstanceOf[ExecScalarSubquery]))
229+
scalarSubqueryFilters.foreach { f =>
230+
f.foreach {
231+
case s: ExecScalarSubquery =>
232+
s.updateResult()
233+
case _ =>
234+
}
235+
}
236+
val resolvedFilters = scalarSubqueryFilters
237+
.map(_.transform { case s: ExecScalarSubquery =>
238+
Literal.create(s.eval(null), s.dataType)
239+
})
240+
if (resolvedFilters.nonEmpty) {
241+
val commonBuilder = base.toBuilder
242+
for (filter <- resolvedFilters) {
243+
exprToProto(filter, output) match {
244+
case Some(proto) => commonBuilder.addDataFilters(proto)
245+
case _ =>
246+
logWarning(s"Could not serialize resolved scalar subquery filter: $filter")
247+
}
248+
}
249+
commonBuilder.build().toByteArray
250+
} else {
251+
base.toByteArray
252+
}
253+
}
213254

214255
// Get file partitions from CometScanExec (handles bucketing, etc.)
215256
val filePartitions = scan.getFilePartitions()
@@ -299,13 +340,15 @@ case class CometNativeScanExec(
299340
case other: CometNativeScanExec =>
300341
this.originalPlan == other.originalPlan &&
301342
this.serializedPlanOpt == other.serializedPlanOpt &&
302-
this.partitionFilters == other.partitionFilters
343+
this.partitionFilters == other.partitionFilters &&
344+
this.dataFilters == other.dataFilters
303345
case _ =>
304346
false
305347
}
306348
}
307349

308-
override def hashCode(): Int = Objects.hashCode(originalPlan, serializedPlanOpt)
350+
override def hashCode(): Int =
351+
Objects.hashCode(originalPlan, serializedPlanOpt, partitionFilters, dataFilters)
309352

310353
private val driverMetricKeys =
311354
Set(

spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala renamed to spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,16 @@
1919

2020
package org.apache.comet.shims
2121

22+
import org.apache.spark.sql.SparkSessionExtensions
23+
import org.apache.spark.sql.catalyst.rules.Rule
2224
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
2325

2426
trait ShimCometSparkSessionExtensions {
2527

26-
/**
27-
* TODO: delete after dropping Spark 3.x support and directly call
28-
* SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key
29-
*/
3028
protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = "spark.sql.extendedExplainProviders"
3129

32-
// Extended info is available only since Spark 4.0.0
33-
// (https://issues.apache.org/jira/browse/SPARK-47289)
3430
def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = {
3531
try {
36-
// Look for QueryExecution.extendedExplainInfo(scala.Function1[String, Unit], SparkPlan)
3732
qe.getClass.getDeclaredMethod(
3833
"extendedExplainInfo",
3934
classOf[String => Unit],
@@ -43,4 +38,9 @@ trait ShimCometSparkSessionExtensions {
4338
}
4439
true
4540
}
41+
42+
// injectQueryStageOptimizerRule is not available on Spark 3.4, so this shim is a no-op.
43+
def injectQueryStageOptimizerRuleShim(
44+
extensions: SparkSessionExtensions,
45+
rule: Rule[SparkPlan]): Unit = {}
4646
}

0 commit comments

Comments
 (0)