Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.spark.sql.comet

import scala.jdk.CollectionConverters._

import org.apache.spark.SparkContext
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
Expand All @@ -41,6 +41,49 @@ import org.apache.comet.serde.Metric
case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode])
extends Logging {

/**
 * Returns the leaf node (deepest single-child descendant). For a native scan plan like
 * FilterExec -> DataSourceExec, this returns the DataSourceExec node which has the
 * bytes_scanned and output_rows metrics from the Parquet reader.
 *
 * Note: only the first child is followed at each level; use [[leafNodes]] when the
 * tree may branch.
 */
def leafNode: CometMetricNode = {
  // Iterative descent along the first-child chain; equivalent to the recursive form.
  var current = this
  while (current.children.nonEmpty) {
    current = current.children.head
  }
  current
}

/**
 * Returns all leaf nodes (nodes with no children) in the metric tree. Unlike [[leafNode]] which
 * only follows the first child, this finds all leaves, which is needed for plans with multiple
 * scans (e.g., joins, unions).
 */
def leafNodes: Seq[CometMetricNode] = children match {
  // A node without children is itself a leaf.
  case Seq() => Seq(this)
  // Otherwise collect the leaves of every subtree.
  case nonEmpty => nonEmpty.flatMap(_.leafNodes)
}

/**
 * Registers a TaskCompletionListener on the given context that, when the task finishes
 * (i.e. after the iterator has been fully consumed), aggregates scan input metrics
 * (bytesRead, recordsRead) across all scan leaf nodes and reports them to Spark's task
 * metrics. Aggregating over all leaves handles plans with multiple scans (e.g., joins).
 */
def reportScanInputMetrics(ctx: TaskContext): Unit = {
  ctx.addTaskCompletionListener[Unit] { _ =>
    // Only leaves produced by a native scan carry the bytes_scanned metric.
    val scanLeaves = leafNodes.filter(_.metrics.contains("bytes_scanned"))
    if (scanLeaves.nonEmpty) {
      var totalBytes = 0L
      var totalRows = 0L
      scanLeaves.foreach { leaf =>
        totalBytes += leaf.metrics("bytes_scanned").value
        // Rows read = rows surviving pushdown + rows pruned by pushdown filters;
        // either metric may be absent, in which case it contributes 0.
        totalRows += leaf.metrics.get("output_rows").fold(0L)(_.value) +
          leaf.metrics.get("pushdown_rows_pruned").fold(0L)(_.value)
      }
      val inputMetrics = ctx.taskMetrics().inputMetrics
      inputMetrics.setBytesRead(totalBytes)
      inputMetrics.setRecordsRead(totalRows)
    }
  }
}

/**
* Gets a child node. Called from native.
*/
Expand Down Expand Up @@ -79,6 +122,7 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM
}
}

// Called via JNI from `comet_metric_node.rs`
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that the only place this will ever be called from? Otherwise I'm not sure the comment is necessary.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IDE highlights the method as unused because it is only called via JNI, so it can be accidentally cleaned up. Added a comment to clarify.

def set_all_from_bytes(bytes: Array[Byte]): Unit = {
val metricNode = Metric.NativeMetricNode.parseFrom(bytes)
set_all(metricNode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.spark.sql.comet

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst._
Expand Down Expand Up @@ -180,18 +181,27 @@ case class CometNativeScanExec(
(None, Seq.empty)
}

CometExecRDD(
new CometExecRDD(
sparkContext,
inputRDDs = Seq.empty,
commonByKey = Map(sourceKey -> commonData),
perPartitionByKey = Map(sourceKey -> perPartitionData),
serializedPlan = serializedPlan,
numPartitions = perPartitionData.length,
numOutputCols = output.length,
nativeMetrics = nativeMetrics,
subqueries = Seq.empty,
broadcastedHadoopConfForEncryption = broadcastedHadoopConfForEncryption,
encryptedFilePaths = encryptedFilePaths)
Seq.empty,
Map(sourceKey -> commonData),
Map(sourceKey -> perPartitionData),
serializedPlan,
perPartitionData.length,
output.length,
nativeMetrics,
Seq.empty,
broadcastedHadoopConfForEncryption,
encryptedFilePaths) {
// Wraps the parent RDD's iterator so that, once the task completes, the native scan's
// metrics (bytes_scanned / output_rows) are folded into Spark's task input metrics.
override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val res = super.compute(split, context)

// Registers a TaskCompletionListener that reports scan input metrics when the task
// finishes, i.e. after the iterator is fully consumed. `Option(...)` guards against a
// null TaskContext (e.g. when invoked outside a running task).
Option(context).foreach(nativeMetrics.reportScanInputMetrics)

res
}
}
}

override def doCanonicalize(): CometNativeScanExec = {
Expand Down
19 changes: 17 additions & 2 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -558,7 +559,8 @@ abstract class CometNativeExec extends CometExec {

// Unified RDD creation - CometExecRDD handles all cases
val subqueries = collectSubqueries(this)
CometExecRDD(
val hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])
new CometExecRDD(
sparkContext,
inputs.toSeq,
commonByKey,
Expand All @@ -570,7 +572,20 @@ abstract class CometNativeExec extends CometExec {
subqueries,
broadcastedHadoopConfForEncryption,
encryptedFilePaths,
shuffleScanIndices)
shuffleScanIndices) {
// Wraps the parent RDD's iterator to attach task-level input-metric reporting.
override def compute(
split: Partition,
context: TaskContext): Iterator[ColumnarBatch] = {
val res = super.compute(split, context)

// Report scan input metrics only when the native plan contains a scan; the
// registered TaskCompletionListener fires after the iterator is fully consumed.
// `Option(...)` guards against a null TaskContext.
if (hasScanInput) {
Option(context).foreach(nativeMetrics.reportScanInputMetrics)
}

res
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,25 @@ package org.apache.spark.sql.comet

import scala.collection.mutable

import org.apache.spark.SparkConf
import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.scheduler.SparkListenerTaskEnd
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper

import org.apache.comet.CometConf

class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {

// NOTE(review): enables the Spark UI, presumably so stage/task input metrics are
// retained in the AppStatusStore that collectInputMetrics reads — confirm this is
// required for the status store to be populated.
override protected def sparkConf: SparkConf = {
super.sparkConf.set("spark.ui.enabled", "true")
}

import testImplicits._

test("per-task native shuffle metrics") {
Expand Down Expand Up @@ -91,4 +99,189 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("native_datafusion scan reports task-level input metrics matching Spark") {
  val totalRows = 10000
  withTempPath { dir =>
    // Write a small multi-file Parquet dataset and expose it as a temp view.
    spark
      .createDataFrame((0 until totalRows).map(i => (i, s"elem_$i")))
      .repartition(5)
      .write
      .parquet(dir.getAbsolutePath)
    spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("tbl")

    val query = "SELECT * FROM tbl where _1 > 2000"

    // Baseline input metrics from vanilla Spark (Comet disabled).
    val (sparkBytes, sparkRecords, _) =
      collectInputMetrics(query, CometConf.COMET_ENABLED.key -> "false")

    // Input metrics from the Comet native_datafusion scan.
    val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
      query,
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

    // The Comet plan must actually contain a native scan node.
    assert(
      find(cometPlan)(_.isInstanceOf[CometNativeScanExec]).isDefined,
      s"Expected CometNativeScanExec in plan:\n${cometPlan.treeString}")

    // Record counts must match Spark exactly.
    assert(sparkRecords > 0, s"Spark outputRecords should be > 0, got $sparkRecords")
    assert(cometRecords > 0, s"Comet outputRecords should be > 0, got $cometRecords")
    assert(
      cometRecords == sparkRecords,
      s"recordsRead mismatch: comet=$cometRecords, sparkRecords=$sparkRecords")

    // Bytes should be in the same ballpark -- both read the same Parquet file(s),
    // but the exact byte count can differ due to reader implementation details
    // (e.g. footer reads, page headers, buffering granularity).
    assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
    assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
    val ratio = cometBytes.toDouble / sparkBytes.toDouble
    assert(
      ratio >= 0.7 && ratio <= 1.3,
      s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
  }
}

// Verifies that CometMetricNode aggregates input metrics across BOTH scan leaves of a
// join plan, comparing totals against a vanilla-Spark baseline for the same query.
test("input metrics aggregate across multiple native scans in a join") {
withTempPath { dir1 =>
withTempPath { dir2 =>
// Create two separate parquet tables
spark
.createDataFrame((0 until 5000).map(i => (i, s"left_$i")))
.repartition(3)
.write
.parquet(dir1.getAbsolutePath)
spark
.createDataFrame((0 until 5000).map(i => (i, s"right_$i")))
.repartition(3)
.write
.parquet(dir2.getAbsolutePath)

spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView("left_tbl")
spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView("right_tbl")

val joinQuery = "SELECT * FROM left_tbl JOIN right_tbl ON left_tbl._1 = right_tbl._1"

// Collect baseline from vanilla Spark
val (sparkBytes, sparkRecords, _) =
collectInputMetrics(joinQuery, CometConf.COMET_ENABLED.key -> "false")

// Collect from Comet native scan
val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
joinQuery,
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

// Verify the plan has multiple CometNativeScanExec nodes (one per join side),
// which is the case this test is designed to exercise.
val scanCount = collect(cometPlan) { case s: CometNativeScanExec =>
s
}.size
assert(
scanCount >= 2,
s"Expected at least 2 CometNativeScanExec in plan, found $scanCount:\n" +
cometPlan.treeString)

assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
assert(sparkRecords > 0, s"Spark recordsRead should be > 0, got $sparkRecords")
assert(cometRecords > 0, s"Comet recordsRead should be > 0, got $cometRecords")

// Both sides should contribute to the total bytes; only a loose ratio is asserted
// because reader implementations count bytes slightly differently.
val ratio = cometBytes.toDouble / sparkBytes.toDouble
assert(
ratio >= 0.7 && ratio <= 1.3,
s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
}
}
}

// Verifies that CometMetricNode aggregates input metrics across BOTH scan leaves of a
// UNION ALL plan, comparing totals against a vanilla-Spark baseline for the same query.
test("input metrics aggregate across multiple native scans in a union") {
withTempPath { dir1 =>
withTempPath { dir2 =>
// Two disjoint parquet tables, one per side of the union.
spark
.createDataFrame((0 until 5000).map(i => (i, s"left_$i")))
.repartition(3)
.write
.parquet(dir1.getAbsolutePath)
spark
.createDataFrame((5000 until 10000).map(i => (i, s"right_$i")))
.repartition(3)
.write
.parquet(dir2.getAbsolutePath)

spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView("union_left")
spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView("union_right")

val unionQuery = "SELECT * FROM union_left UNION ALL SELECT * FROM union_right"

// Collect baseline from vanilla Spark
val (sparkBytes, sparkRecords, _) =
collectInputMetrics(unionQuery, CometConf.COMET_ENABLED.key -> "false")

// Collect from Comet native scan
val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
unionQuery,
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

// Verify the plan has multiple CometNativeScanExec nodes (one per union branch),
// which is the case this test is designed to exercise.
val scanCount = collect(cometPlan) { case s: CometNativeScanExec =>
s
}.size
assert(
scanCount >= 2,
s"Expected at least 2 CometNativeScanExec in plan, found $scanCount:\n" +
cometPlan.treeString)

assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
assert(sparkRecords > 0, s"Spark recordsRead should be > 0, got $sparkRecords")
assert(cometRecords > 0, s"Comet recordsRead should be > 0, got $cometRecords")

// Loose ratio check: both engines read the same files but count bytes differently.
val ratio = cometBytes.toDouble / sparkBytes.toDouble
assert(
ratio >= 0.7 && ratio <= 1.3,
s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
}
}
}

/**
 * Runs the given query with the given SQL config overrides and returns the aggregated
 * (bytesRead, recordsRead) across all tasks, along with the executed plan.
 *
 * Uses AppStatusStore (same source as Spark UI) to read task-level input metrics.
 * AppStatusStore stores immutable snapshots of metric values, unlike SparkListener's
 * InputMetrics which are backed by mutable accumulators that can be reset.
 */
private def collectInputMetrics(
    query: String,
    confs: (String, String)*): (Long, Long, SparkPlan) = {
  val store = spark.sparkContext.statusStore

  // Snapshot the stage IDs that already exist so only stages produced by this
  // query are counted below.
  val preExistingStages = store.stageList(null).map(_.stageId).toSet

  // withSQLConf's body is Unit-typed, so the executed plan is captured via a var.
  var plan: SparkPlan = null
  withSQLConf(confs: _*) {
    val df = sql(query)
    df.collect()
    plan = stripAQEPlan(df.queryExecution.executedPlan)
  }

  // Make sure every task/stage event has been applied to the status store.
  spark.sparkContext.listenerBus.waitUntilEmpty()

  val queryStages = store.stageList(null).filterNot(s => preExistingStages.contains(s.stageId))
  assert(queryStages.nonEmpty, s"No new stages found for confs=$confs")

  // Aggregate input metrics over all stages of this query in a single pass.
  val (totalBytes, totalRecords) =
    queryStages.foldLeft((0L, 0L)) { case ((bytes, records), stage) =>
      (bytes + stage.inputBytes, records + stage.inputRecords)
    }

  (totalBytes, totalRecords, plan)
}
}
Loading