@@ -21,17 +21,25 @@ package org.apache.spark.sql.comet
2121
2222import scala .collection .mutable
2323
24+ import org .apache .spark .SparkConf
2425import org .apache .spark .executor .ShuffleReadMetrics
2526import org .apache .spark .executor .ShuffleWriteMetrics
2627import org .apache .spark .scheduler .SparkListener
2728import org .apache .spark .scheduler .SparkListenerTaskEnd
2829import org .apache .spark .sql .CometTestBase
2930import org .apache .spark .sql .comet .execution .shuffle .CometNativeShuffle
3031import org .apache .spark .sql .comet .execution .shuffle .CometShuffleExchangeExec
32+ import org .apache .spark .sql .execution .SparkPlan
3133import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
3234
35+ import org .apache .comet .CometConf
36+
3337class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
3438
39+ override protected def sparkConf : SparkConf = {
40+ super .sparkConf.set(" spark.ui.enabled" , " true" )
41+ }
42+
3543 import testImplicits ._
3644
3745 test(" per-task native shuffle metrics" ) {
@@ -91,4 +99,189 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
9199 }
92100 }
93101 }
102+
103+ test(" native_datafusion scan reports task-level input metrics matching Spark" ) {
104+ val totalRows = 10000
105+ withTempPath { dir =>
106+ spark
107+ .createDataFrame((0 until totalRows).map(i => (i, s " elem_ $i" )))
108+ .repartition(5 )
109+ .write
110+ .parquet(dir.getAbsolutePath)
111+ spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView(" tbl" )
112+ // Collect baseline input metrics from vanilla Spark (Comet disabled)
113+ val (sparkBytes, sparkRecords, _) =
114+ collectInputMetrics(
115+ " SELECT * FROM tbl where _1 > 2000" ,
116+ CometConf .COMET_ENABLED .key -> " false" )
117+
118+ // Collect input metrics from Comet native_datafusion scan.
119+ val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
120+ " SELECT * FROM tbl where _1 > 2000" ,
121+ CometConf .COMET_ENABLED .key -> " true" ,
122+ CometConf .COMET_NATIVE_SCAN_IMPL .key -> CometConf .SCAN_NATIVE_DATAFUSION )
123+
124+ // Verify the plan actually used CometNativeScanExec
125+ assert(
126+ find(cometPlan)(_.isInstanceOf [CometNativeScanExec ]).isDefined,
127+ s " Expected CometNativeScanExec in plan: \n ${cometPlan.treeString}" )
128+
129+ assert(sparkRecords > 0 , s " Spark outputRecords should be > 0, got $sparkRecords" )
130+ assert(cometRecords > 0 , s " Comet outputRecords should be > 0, got $cometRecords" )
131+
132+ assert(
133+ cometRecords == sparkRecords,
134+ s " recordsRead mismatch: comet= $cometRecords, sparkRecords= $sparkRecords" )
135+
136+ // Bytes should be in the same ballpark -- both read the same Parquet file(s),
137+ // but the exact byte count can differ due to reader implementation details
138+ // (e.g. footer reads, page headers, buffering granularity).
139+ assert(sparkBytes > 0 , s " Spark bytesRead should be > 0, got $sparkBytes" )
140+ assert(cometBytes > 0 , s " Comet bytesRead should be > 0, got $cometBytes" )
141+ val ratio = cometBytes.toDouble / sparkBytes.toDouble
142+ assert(
143+ ratio >= 0.7 && ratio <= 1.3 ,
144+ s " bytesRead ratio out of range: comet= $cometBytes, spark= $sparkBytes, ratio= $ratio" )
145+ }
146+ }
147+
148+ test(" input metrics aggregate across multiple native scans in a join" ) {
149+ withTempPath { dir1 =>
150+ withTempPath { dir2 =>
151+ // Create two separate parquet tables
152+ spark
153+ .createDataFrame((0 until 5000 ).map(i => (i, s " left_ $i" )))
154+ .repartition(3 )
155+ .write
156+ .parquet(dir1.getAbsolutePath)
157+ spark
158+ .createDataFrame((0 until 5000 ).map(i => (i, s " right_ $i" )))
159+ .repartition(3 )
160+ .write
161+ .parquet(dir2.getAbsolutePath)
162+
163+ spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView(" left_tbl" )
164+ spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView(" right_tbl" )
165+
166+ val joinQuery = " SELECT * FROM left_tbl JOIN right_tbl ON left_tbl._1 = right_tbl._1"
167+
168+ // Collect baseline from vanilla Spark
169+ val (sparkBytes, sparkRecords, _) =
170+ collectInputMetrics(joinQuery, CometConf .COMET_ENABLED .key -> " false" )
171+
172+ // Collect from Comet native scan
173+ val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
174+ joinQuery,
175+ CometConf .COMET_ENABLED .key -> " true" ,
176+ CometConf .COMET_NATIVE_SCAN_IMPL .key -> CometConf .SCAN_NATIVE_DATAFUSION )
177+
178+ // Verify the plan has multiple CometNativeScanExec nodes
179+ val scanCount = collect(cometPlan) { case s : CometNativeScanExec =>
180+ s
181+ }.size
182+ assert(
183+ scanCount >= 2 ,
184+ s " Expected at least 2 CometNativeScanExec in plan, found $scanCount: \n " +
185+ cometPlan.treeString)
186+
187+ assert(sparkBytes > 0 , s " Spark bytesRead should be > 0, got $sparkBytes" )
188+ assert(cometBytes > 0 , s " Comet bytesRead should be > 0, got $cometBytes" )
189+ assert(sparkRecords > 0 , s " Spark recordsRead should be > 0, got $sparkRecords" )
190+ assert(cometRecords > 0 , s " Comet recordsRead should be > 0, got $cometRecords" )
191+
192+ // Both sides should contribute to the total bytes
193+ val ratio = cometBytes.toDouble / sparkBytes.toDouble
194+ assert(
195+ ratio >= 0.7 && ratio <= 1.3 ,
196+ s " bytesRead ratio out of range: comet= $cometBytes, spark= $sparkBytes, ratio= $ratio" )
197+ }
198+ }
199+ }
200+
201+ test(" input metrics aggregate across multiple native scans in a union" ) {
202+ withTempPath { dir1 =>
203+ withTempPath { dir2 =>
204+ spark
205+ .createDataFrame((0 until 5000 ).map(i => (i, s " left_ $i" )))
206+ .repartition(3 )
207+ .write
208+ .parquet(dir1.getAbsolutePath)
209+ spark
210+ .createDataFrame((5000 until 10000 ).map(i => (i, s " right_ $i" )))
211+ .repartition(3 )
212+ .write
213+ .parquet(dir2.getAbsolutePath)
214+
215+ spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView(" union_left" )
216+ spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView(" union_right" )
217+
218+ val unionQuery = " SELECT * FROM union_left UNION ALL SELECT * FROM union_right"
219+
220+ // Collect baseline from vanilla Spark
221+ val (sparkBytes, sparkRecords, _) =
222+ collectInputMetrics(unionQuery, CometConf .COMET_ENABLED .key -> " false" )
223+
224+ // Collect from Comet native scan
225+ val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
226+ unionQuery,
227+ CometConf .COMET_ENABLED .key -> " true" ,
228+ CometConf .COMET_NATIVE_SCAN_IMPL .key -> CometConf .SCAN_NATIVE_DATAFUSION )
229+
230+ // Verify the plan has multiple CometNativeScanExec nodes
231+ val scanCount = collect(cometPlan) { case s : CometNativeScanExec =>
232+ s
233+ }.size
234+ assert(
235+ scanCount >= 2 ,
236+ s " Expected at least 2 CometNativeScanExec in plan, found $scanCount: \n " +
237+ cometPlan.treeString)
238+
239+ assert(sparkBytes > 0 , s " Spark bytesRead should be > 0, got $sparkBytes" )
240+ assert(cometBytes > 0 , s " Comet bytesRead should be > 0, got $cometBytes" )
241+ assert(sparkRecords > 0 , s " Spark recordsRead should be > 0, got $sparkRecords" )
242+ assert(cometRecords > 0 , s " Comet recordsRead should be > 0, got $cometRecords" )
243+
244+ val ratio = cometBytes.toDouble / sparkBytes.toDouble
245+ assert(
246+ ratio >= 0.7 && ratio <= 1.3 ,
247+ s " bytesRead ratio out of range: comet= $cometBytes, spark= $sparkBytes, ratio= $ratio" )
248+ }
249+ }
250+ }
251+
252+ /**
253+ * Runs the given query with the given SQL config overrides and returns the aggregated
254+ * (bytesRead, recordsRead) across all tasks, along with the executed plan.
255+ *
256+ * Uses AppStatusStore (same source as Spark UI) to read task-level input metrics.
257+ * AppStatusStore stores immutable snapshots of metric values, unlike SparkListener's
258+ * InputMetrics which are backed by mutable accumulators that can be reset.
259+ */
260+ private def collectInputMetrics (
261+ query : String ,
262+ confs : (String , String )* ): (Long , Long , SparkPlan ) = {
263+ val store = spark.sparkContext.statusStore
264+
265+ // Record existing stage IDs so we only look at stages from our query
266+ val stagesBefore = store.stageList(null ).map(_.stageId).toSet
267+
268+ var plan : SparkPlan = null
269+ withSQLConf(confs : _* ) {
270+ val df = sql(query)
271+ df.collect()
272+ plan = stripAQEPlan(df.queryExecution.executedPlan)
273+ }
274+
275+ // Wait for listener bus to flush all events into the status store
276+ spark.sparkContext.listenerBus.waitUntilEmpty()
277+
278+ // Sum input metrics from stages created by our query
279+ val newStages = store.stageList(null ).filterNot(s => stagesBefore.contains(s.stageId))
280+ assert(newStages.nonEmpty, s " No new stages found for confs= $confs" )
281+
282+ val totalBytes = newStages.map(_.inputBytes).sum
283+ val totalRecords = newStages.map(_.inputRecords).sum
284+
285+ (totalBytes, totalRecords, plan)
286+ }
94287}
0 commit comments