Merged · 16 commits
12 changes: 12 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
@@ -139,6 +139,18 @@ private[spark] class CometExecRDD(
ctx.addTaskCompletionListener[Unit] { _ =>
it.close()
subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub))

// Propagate native scan metrics (bytes_scanned, output_rows) to Spark's task-level
// inputMetrics so they appear in the Spark UI "Input" column and are reported via
// the listener infrastructure. The native reader bypasses Hadoop's Java FileSystem,
// so thread-local FS statistics are never updated -- we bridge the gap here.
val bytesScannedMetric = nativeMetrics.findMetric("bytes_scanned")
val outputRowsMetric = nativeMetrics.findMetric("output_rows")
if (bytesScannedMetric.isDefined || outputRowsMetric.isDefined) {
val inputMetrics = ctx.taskMetrics().inputMetrics
bytesScannedMetric.foreach(m => inputMetrics.setBytesRead(m.value))

Contributor:
foreach already handles the None case when the metric is missing, so wrapping this in an if is unnecessary. It saves evaluating ctx.taskMetrics().inputMetrics once, but the result is oddly-structured conditional logic.

Contributor Author:
Agree.

outputRowsMetric.foreach(m => inputMetrics.setRecordsRead(m.value))
}
}
}
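The review above settles on dropping the guard, since Option.foreach is already a no-op for None. A self-contained sketch of that shape, using a toy TaskInputMetrics class and a plain lookup function as hypothetical stand-ins for Spark's InputMetrics and CometMetricNode.findMetric (not the real API):

```scala
// Toy stand-in for Spark's mutable task-level InputMetrics.
final class TaskInputMetrics {
  var bytesRead: Long = 0L
  var recordsRead: Long = 0L
}

// No guard needed: a lookup that returns None simply leaves the
// corresponding field untouched.
def propagate(findMetric: String => Option[Long], im: TaskInputMetrics): Unit = {
  findMetric("bytes_scanned").foreach(v => im.bytesRead = v)
  findMetric("output_rows").foreach(v => im.recordsRead = v)
}

val im = new TaskInputMetrics
propagate(Map("bytes_scanned" -> 4096L, "output_rows" -> 100L).get, im)
// im.bytesRead == 4096, im.recordsRead == 100
```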

@@ -79,10 +79,21 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM
}
}

// Called via JNI from `comet_metric_node.rs`
Contributor:
Is that the only place this will ever be called from? Otherwise I'm not sure the comment is necessary.

Contributor Author:
The IDE highlights the method as unused because it is only called via JNI, so it could be accidentally removed during cleanup. I added the comment to clarify.

def set_all_from_bytes(bytes: Array[Byte]): Unit = {
val metricNode = Metric.NativeMetricNode.parseFrom(bytes)
set_all(metricNode)
}

/**
* Finds a metric by name in this node or any descendant node. Returns the first match found via
* depth-first search.
*/
def findMetric(name: String): Option[SQLMetric] = {
metrics.get(name).orElse {
children.iterator.map(_.findMetric(name)).collectFirst { case Some(m) => m }
Contributor:
Doesn't this just return the first match it finds with the metric name? Can't multiple plans have nodes that report "output_rows"?

Contributor Author:
Hmm, what if we restrict output_rows to scan nodes?

}
}
}

object CometMetricNode {
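The first-match behavior the reviewer questions can be seen in a toy model. Here MetricNode with plain Long values is a hypothetical stand-in for CometMetricNode, which wraps Spark SQLMetric objects:

```scala
// Simplified model of findMetric's depth-first, first-match lookup.
case class MetricNode(metrics: Map[String, Long], children: Seq[MetricNode]) {
  def findMetric(name: String): Option[Long] =
    metrics.get(name).orElse(
      children.iterator.map(_.findMetric(name)).collectFirst { case Some(m) => m })
}

// Two nodes both report "output_rows"; depth-first search returns the
// first match, so scanB's count is never seen -- the reviewer's concern.
val scanA = MetricNode(Map("output_rows" -> 100L), Nil)
val scanB = MetricNode(Map("output_rows" -> 200L), Nil)
val root  = MetricNode(Map.empty, Seq(scanA, scanB))
root.findMetric("output_rows") // Some(100)
```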
@@ -21,6 +21,7 @@ package org.apache.spark.sql.comet

import scala.collection.mutable

import org.apache.spark.executor.InputMetrics
import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.SparkListener
@@ -30,6 +31,8 @@ import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper

import org.apache.comet.CometConf

class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {

import testImplicits._
@@ -91,4 +94,66 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("native_datafusion scan reports task-level input metrics matching Spark") {
withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") {
// Collect baseline input metrics from vanilla Spark (Comet disabled)
val (sparkBytes, sparkRecords) = collectInputMetrics(CometConf.COMET_ENABLED.key -> "false")

// Collect input metrics from Comet native_datafusion scan
val (cometBytes, cometRecords) = collectInputMetrics(
CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)
Member:
Suggested change -- also set Comet enabled explicitly:

    val (cometBytes, cometRecords) = collectInputMetrics(
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

Contributor Author:
CometConf.COMET_ENABLED.key -> "true" is already set at the test level by default, but I agree we should verify that the Comet operators were actually applied.


// Records must match exactly
assert(
cometRecords == sparkRecords,
s"recordsRead mismatch: comet=$cometRecords, spark=$sparkRecords")

// Bytes should be in the same ballpark -- both read the same Parquet file(s),
// but the exact byte count can differ due to reader implementation details
// (e.g. footer reads, page headers, buffering granularity).
assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
val ratio = cometBytes.toDouble / sparkBytes.toDouble
assert(
ratio >= 0.8 && ratio <= 1.2,
s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
}
}

/**
* Runs `SELECT * FROM tbl` with the given SQL config overrides and returns the aggregated
* (bytesRead, recordsRead) across all tasks.
*/
private def collectInputMetrics(confs: (String, String)*): (Long, Long) = {
val inputMetricsList = mutable.ArrayBuffer.empty[InputMetrics]

val listener = new SparkListener {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val im = taskEnd.taskMetrics.inputMetrics
inputMetricsList.synchronized {
inputMetricsList += im
}
}
}

spark.sparkContext.addSparkListener(listener)
try {
// Drain any earlier events
spark.sparkContext.listenerBus.waitUntilEmpty()

withSQLConf(confs: _*) {
sql("SELECT * FROM tbl").collect()
Member:
Suggested change -- add a filter to make it more realistic:

    sql("SELECT * FROM tbl WHERE _1 > 5000").collect()

Contributor Author:
Thanks @martin-g, why would the filter be needed? I'd prefer to keep the repro as simple as possible.

Contributor:
A filter would show the discrepancy/incorrect values when the scan isn't the first child node.

}

spark.sparkContext.listenerBus.waitUntilEmpty()

assert(inputMetricsList.nonEmpty, s"No input metrics found for confs=$confs")
val totalBytes = inputMetricsList.map(_.bytesRead).sum
val totalRecords = inputMetricsList.map(_.recordsRead).sum
(totalBytes, totalRecords)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}
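Following up on the thread about ensuring the Comet operators were actually applied: one way is to assert on the executed plan rather than on the config alone. This is a hypothetical sketch only -- it assumes the CometTestBase/AdaptiveSparkPlanHelper context of the suite above, and `CometNativeScanExec` is an assumed operator name, not verified here:

```scala
// Hypothetical: assert a Comet native scan is present in the executed plan,
// instead of trusting COMET_ENABLED alone.
val df = sql("SELECT * FROM tbl")
df.collect()
val cometScans = collect(df.queryExecution.executedPlan) {
  case scan: CometNativeScanExec => scan
}
assert(cometScans.nonEmpty, "expected a Comet native scan in the executed plan")
```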