Skip to content

Commit 99aed8e

Browse files
authored
chore: native_datafusion to report scan task input metrics (#3842)
1 parent 90633dc commit 99aed8e

4 files changed

Lines changed: 276 additions & 14 deletions

File tree

spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.spark.sql.comet
2121

2222
import scala.jdk.CollectionConverters._
2323

24-
import org.apache.spark.SparkContext
24+
import org.apache.spark.{SparkContext, TaskContext}
2525
import org.apache.spark.internal.Logging
2626
import org.apache.spark.sql.execution.SparkPlan
2727
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -41,6 +41,49 @@ import org.apache.comet.serde.Metric
4141
case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode])
4242
extends Logging {
4343

44+
/**
45+
* Returns the leaf node (deepest single-child descendant). For a native scan plan like
46+
* FilterExec -> DataSourceExec, this returns the DataSourceExec node which has the
47+
* bytes_scanned and output_rows metrics from the Parquet reader.
48+
*/
49+
def leafNode: CometMetricNode = {
50+
if (children.isEmpty) this
51+
else children.head.leafNode
52+
}
53+
54+
/**
55+
* Returns all leaf nodes (nodes with no children) in the metric tree. Unlike [[leafNode]] which
56+
* only follows the first child, this finds all leaves, which is needed for plans with multiple
57+
* scans (e.g., joins, unions).
58+
*/
59+
def leafNodes: Seq[CometMetricNode] = {
60+
if (children.isEmpty) Seq(this)
61+
else children.flatMap(_.leafNodes)
62+
}
63+
64+
/**
65+
* Reports aggregated scan input metrics (bytesRead, recordsRead) to Spark's task metrics.
66+
* Aggregates across all scan leaf nodes to handle plans with multiple scans (e.g., joins). Must
67+
* be called in a TaskCompletionListener after the iterator is fully consumed.
68+
*/
69+
def reportScanInputMetrics(ctx: TaskContext): Unit = {
70+
ctx.addTaskCompletionListener[Unit] { _ =>
71+
val scanLeaves = leafNodes.filter(_.metrics.contains("bytes_scanned"))
72+
if (scanLeaves.nonEmpty) {
73+
val totalBytes = scanLeaves.map(_.metrics("bytes_scanned").value).sum
74+
val totalRows = scanLeaves.map { leaf =>
75+
val outputRows =
76+
leaf.metrics.get("output_rows").map(_.value).getOrElse(0L)
77+
val prunedRows =
78+
leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L)
79+
outputRows + prunedRows
80+
}.sum
81+
ctx.taskMetrics().inputMetrics.setBytesRead(totalBytes)
82+
ctx.taskMetrics().inputMetrics.setRecordsRead(totalRows)
83+
}
84+
}
85+
}
86+
4487
/**
4588
* Gets a child node. Called from native.
4689
*/
@@ -79,6 +122,7 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM
79122
}
80123
}
81124

125+
// Called via JNI from `comet_metric_node.rs`
82126
def set_all_from_bytes(bytes: Array[Byte]): Unit = {
83127
val metricNode = Metric.NativeMetricNode.parseFrom(bytes)
84128
set_all(metricNode)

spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package org.apache.spark.sql.comet
2121

2222
import scala.reflect.ClassTag
2323

24+
import org.apache.spark.{Partition, TaskContext}
2425
import org.apache.spark.rdd.RDD
2526
import org.apache.spark.sql.SparkSession
2627
import org.apache.spark.sql.catalyst._
@@ -180,18 +181,27 @@ case class CometNativeScanExec(
180181
(None, Seq.empty)
181182
}
182183

183-
CometExecRDD(
184+
new CometExecRDD(
184185
sparkContext,
185-
inputRDDs = Seq.empty,
186-
commonByKey = Map(sourceKey -> commonData),
187-
perPartitionByKey = Map(sourceKey -> perPartitionData),
188-
serializedPlan = serializedPlan,
189-
numPartitions = perPartitionData.length,
190-
numOutputCols = output.length,
191-
nativeMetrics = nativeMetrics,
192-
subqueries = Seq.empty,
193-
broadcastedHadoopConfForEncryption = broadcastedHadoopConfForEncryption,
194-
encryptedFilePaths = encryptedFilePaths)
186+
Seq.empty,
187+
Map(sourceKey -> commonData),
188+
Map(sourceKey -> perPartitionData),
189+
serializedPlan,
190+
perPartitionData.length,
191+
output.length,
192+
nativeMetrics,
193+
Seq.empty,
194+
broadcastedHadoopConfForEncryption,
195+
encryptedFilePaths) {
196+
override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
197+
val res = super.compute(split, context)
198+
199+
// Report scan input metrics after the iterator is fully consumed.
200+
Option(context).foreach(nativeMetrics.reportScanInputMetrics)
201+
202+
res
203+
}
204+
}
195205
}
196206

197207
override def doCanonicalize(): CometNativeScanExec = {

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import scala.collection.mutable
2525
import scala.collection.mutable.ArrayBuffer
2626
import scala.jdk.CollectionConverters._
2727

28+
import org.apache.spark.{Partition, TaskContext}
2829
import org.apache.spark.broadcast.Broadcast
2930
import org.apache.spark.rdd.RDD
3031
import org.apache.spark.sql.catalyst.InternalRow
@@ -558,7 +559,8 @@ abstract class CometNativeExec extends CometExec {
558559

559560
// Unified RDD creation - CometExecRDD handles all cases
560561
val subqueries = collectSubqueries(this)
561-
CometExecRDD(
562+
val hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])
563+
new CometExecRDD(
562564
sparkContext,
563565
inputs.toSeq,
564566
commonByKey,
@@ -570,7 +572,20 @@ abstract class CometNativeExec extends CometExec {
570572
subqueries,
571573
broadcastedHadoopConfForEncryption,
572574
encryptedFilePaths,
573-
shuffleScanIndices)
575+
shuffleScanIndices) {
576+
override def compute(
577+
split: Partition,
578+
context: TaskContext): Iterator[ColumnarBatch] = {
579+
val res = super.compute(split, context)
580+
581+
// Report scan input metrics only when the native plan contains a scan.
582+
if (hasScanInput) {
583+
Option(context).foreach(nativeMetrics.reportScanInputMetrics)
584+
}
585+
586+
res
587+
}
588+
}
574589
}
575590
}
576591

spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,25 @@ package org.apache.spark.sql.comet
2121

2222
import scala.collection.mutable
2323

24+
import org.apache.spark.SparkConf
2425
import org.apache.spark.executor.ShuffleReadMetrics
2526
import org.apache.spark.executor.ShuffleWriteMetrics
2627
import org.apache.spark.scheduler.SparkListener
2728
import org.apache.spark.scheduler.SparkListenerTaskEnd
2829
import org.apache.spark.sql.CometTestBase
2930
import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle
3031
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
32+
import org.apache.spark.sql.execution.SparkPlan
3133
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
3234

35+
import org.apache.comet.CometConf
36+
3337
class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
3438

39+
override protected def sparkConf: SparkConf = {
40+
super.sparkConf.set("spark.ui.enabled", "true")
41+
}
42+
3543
import testImplicits._
3644

3745
test("per-task native shuffle metrics") {
@@ -91,4 +99,189 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
9199
}
92100
}
93101
}
102+
103+
  // Verifies that a single native_datafusion scan publishes task-level input metrics
  // (bytesRead / recordsRead) that line up with what vanilla Spark reports for the same
  // query over the same Parquet data.
  test("native_datafusion scan reports task-level input metrics matching Spark") {
    val totalRows = 10000
    withTempPath { dir =>
      // Write a small multi-file Parquet dataset and register it as a temp view.
      spark
        .createDataFrame((0 until totalRows).map(i => (i, s"elem_$i")))
        .repartition(5)
        .write
        .parquet(dir.getAbsolutePath)
      spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("tbl")
      // Collect baseline input metrics from vanilla Spark (Comet disabled)
      val (sparkBytes, sparkRecords, _) =
        collectInputMetrics(
          "SELECT * FROM tbl where _1 > 2000",
          CometConf.COMET_ENABLED.key -> "false")

      // Collect input metrics from Comet native_datafusion scan.
      val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
        "SELECT * FROM tbl where _1 > 2000",
        CometConf.COMET_ENABLED.key -> "true",
        CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

      // Verify the plan actually used CometNativeScanExec
      assert(
        find(cometPlan)(_.isInstanceOf[CometNativeScanExec]).isDefined,
        s"Expected CometNativeScanExec in plan:\n${cometPlan.treeString}")

      assert(sparkRecords > 0, s"Spark outputRecords should be > 0, got $sparkRecords")
      assert(cometRecords > 0, s"Comet outputRecords should be > 0, got $cometRecords")

      // Record counts must match exactly: both engines read the same rows from the files.
      assert(
        cometRecords == sparkRecords,
        s"recordsRead mismatch: comet=$cometRecords, sparkRecords=$sparkRecords")

      // Bytes should be in the same ballpark -- both read the same Parquet file(s),
      // but the exact byte count can differ due to reader implementation details
      // (e.g. footer reads, page headers, buffering granularity).
      assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
      assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
      val ratio = cometBytes.toDouble / sparkBytes.toDouble
      assert(
        ratio >= 0.7 && ratio <= 1.3,
        s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
    }
  }
147+
148+
  // Verifies that input metrics from several native scans in one plan (a two-table join)
  // are summed together rather than one scan's values overwriting the other's.
  test("input metrics aggregate across multiple native scans in a join") {
    withTempPath { dir1 =>
      withTempPath { dir2 =>
        // Create two separate parquet tables
        spark
          .createDataFrame((0 until 5000).map(i => (i, s"left_$i")))
          .repartition(3)
          .write
          .parquet(dir1.getAbsolutePath)
        spark
          .createDataFrame((0 until 5000).map(i => (i, s"right_$i")))
          .repartition(3)
          .write
          .parquet(dir2.getAbsolutePath)

        spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView("left_tbl")
        spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView("right_tbl")

        val joinQuery = "SELECT * FROM left_tbl JOIN right_tbl ON left_tbl._1 = right_tbl._1"

        // Collect baseline from vanilla Spark
        val (sparkBytes, sparkRecords, _) =
          collectInputMetrics(joinQuery, CometConf.COMET_ENABLED.key -> "false")

        // Collect from Comet native scan
        val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
          joinQuery,
          CometConf.COMET_ENABLED.key -> "true",
          CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

        // Verify the plan has multiple CometNativeScanExec nodes
        val scanCount = collect(cometPlan) { case s: CometNativeScanExec =>
          s
        }.size
        assert(
          scanCount >= 2,
          s"Expected at least 2 CometNativeScanExec in plan, found $scanCount:\n" +
            cometPlan.treeString)

        // NOTE(review): unlike the single-scan test, record counts are only checked for
        // being non-zero — exact parity with Spark is not asserted here; confirm whether
        // per-scan pruned-row accounting makes exact equality unreliable for joins.
        assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
        assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
        assert(sparkRecords > 0, s"Spark recordsRead should be > 0, got $sparkRecords")
        assert(cometRecords > 0, s"Comet recordsRead should be > 0, got $cometRecords")

        // Both sides should contribute to the total bytes
        val ratio = cometBytes.toDouble / sparkBytes.toDouble
        assert(
          ratio >= 0.7 && ratio <= 1.3,
          s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
      }
    }
  }
}
200+
201+
  // Verifies aggregation of input metrics across multiple native scans for UNION ALL,
  // where both scans feed the same stage.
  test("input metrics aggregate across multiple native scans in a union") {
    withTempPath { dir1 =>
      withTempPath { dir2 =>
        // Two disjoint Parquet tables so each side contributes distinct rows/bytes.
        spark
          .createDataFrame((0 until 5000).map(i => (i, s"left_$i")))
          .repartition(3)
          .write
          .parquet(dir1.getAbsolutePath)
        spark
          .createDataFrame((5000 until 10000).map(i => (i, s"right_$i")))
          .repartition(3)
          .write
          .parquet(dir2.getAbsolutePath)

        spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView("union_left")
        spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView("union_right")

        val unionQuery = "SELECT * FROM union_left UNION ALL SELECT * FROM union_right"

        // Collect baseline from vanilla Spark
        val (sparkBytes, sparkRecords, _) =
          collectInputMetrics(unionQuery, CometConf.COMET_ENABLED.key -> "false")

        // Collect from Comet native scan
        val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
          unionQuery,
          CometConf.COMET_ENABLED.key -> "true",
          CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

        // Verify the plan has multiple CometNativeScanExec nodes
        val scanCount = collect(cometPlan) { case s: CometNativeScanExec =>
          s
        }.size
        assert(
          scanCount >= 2,
          s"Expected at least 2 CometNativeScanExec in plan, found $scanCount:\n" +
            cometPlan.treeString)

        assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
        assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
        assert(sparkRecords > 0, s"Spark recordsRead should be > 0, got $sparkRecords")
        assert(cometRecords > 0, s"Comet recordsRead should be > 0, got $cometRecords")

        // Byte counts only need to be in the same ballpark (reader implementations differ).
        val ratio = cometBytes.toDouble / sparkBytes.toDouble
        assert(
          ratio >= 0.7 && ratio <= 1.3,
          s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
      }
    }
  }
}
251+
252+
/**
253+
* Runs the given query with the given SQL config overrides and returns the aggregated
254+
* (bytesRead, recordsRead) across all tasks, along with the executed plan.
255+
*
256+
* Uses AppStatusStore (same source as Spark UI) to read task-level input metrics.
257+
* AppStatusStore stores immutable snapshots of metric values, unlike SparkListener's
258+
* InputMetrics which are backed by mutable accumulators that can be reset.
259+
*/
260+
private def collectInputMetrics(
261+
query: String,
262+
confs: (String, String)*): (Long, Long, SparkPlan) = {
263+
val store = spark.sparkContext.statusStore
264+
265+
// Record existing stage IDs so we only look at stages from our query
266+
val stagesBefore = store.stageList(null).map(_.stageId).toSet
267+
268+
var plan: SparkPlan = null
269+
withSQLConf(confs: _*) {
270+
val df = sql(query)
271+
df.collect()
272+
plan = stripAQEPlan(df.queryExecution.executedPlan)
273+
}
274+
275+
// Wait for listener bus to flush all events into the status store
276+
spark.sparkContext.listenerBus.waitUntilEmpty()
277+
278+
// Sum input metrics from stages created by our query
279+
val newStages = store.stageList(null).filterNot(s => stagesBefore.contains(s.stageId))
280+
assert(newStages.nonEmpty, s"No new stages found for confs=$confs")
281+
282+
val totalBytes = newStages.map(_.inputBytes).sum
283+
val totalRecords = newStages.map(_.inputRecords).sum
284+
285+
(totalBytes, totalRecords, plan)
286+
}
94287
}

0 commit comments

Comments
 (0)