Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2438,7 +2438,7 @@ impl PhysicalPlanner {
sort_phy_exprs,
window_frame.into(),
input_schema,
false, // TODO: Ignore nulls
spark_expr.ignore_nulls,
false, // TODO: Spark does not support DISTINCT ... OVER
None,
)
Expand Down Expand Up @@ -2492,6 +2492,12 @@ impl PhysicalPlanner {
.udaf(name)
.map(WindowFunctionDefinition::AggregateUDF)
.ok()
.or_else(|| {
registry
.udwf(name)
.map(WindowFunctionDefinition::WindowUDF)
.ok()
})
}

/// Create a DataFusion physical partitioning from Spark physical partitioning
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ message WindowExpr {
spark.spark_expression.Expr built_in_window_function = 1;
spark.spark_expression.AggExpr agg_func = 2;
WindowSpecDefinition spec = 3;
bool ignore_nulls = 4;
}

enum WindowFrameType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.spark.sql.comet

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CurrentRow, Expression, NamedExpression, RangeFrame, RowFrame, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CurrentRow, Expression, FrameLessOffsetWindowFunction, Lag, Lead, NamedExpression, RangeFrame, RowFrame, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max, Min, Sum}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
Expand All @@ -36,7 +36,7 @@ import org.apache.comet.{CometConf, ConfigEntry}
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.{AggSerde, CometOperatorSerde, Incompatible, OperatorOuterClass, SupportLevel}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto}
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, scalarFunctionExprToProto}

object CometWindowExec extends CometOperatorSerde[WindowExec] {

Expand Down Expand Up @@ -72,7 +72,12 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
return None
}

if (op.partitionSpec.nonEmpty && op.orderSpec.nonEmpty &&
// Offset window functions (LAG, LEAD) support arbitrary partition and order specs, so skip
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering if FIRST_VALUE, LAST_VALUE, NTH_VALUE are also offset window functions, because they also access the data within the frame by some offset (FIRST_VALUE by 1, etc.)?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, in Spark

Lag / Lead → inherit FrameLessOffsetWindowFunction
NthValue → inherit AggregateWindowFunction with OffsetWindowFunction
First / Last → inherit DeclarativeAggregate

Only FrameLessOffsetWindowFunction doesn't require a frame; currently in Spark only Lag / Lead are FrameLessOffsetWindowFunction.

// the validatePartitionAndSortSpecsForWindowFunc check which requires partition columns to
// equal order columns. That stricter check is only needed for aggregate window functions.
val hasOnlyOffsetFunctions = winExprs.nonEmpty &&
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining the logic here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, added a comment.

winExprs.forall(e => e.windowFunction.isInstanceOf[FrameLessOffsetWindowFunction])
if (!hasOnlyOffsetFunctions && op.partitionSpec.nonEmpty && op.orderSpec.nonEmpty &&
!validatePartitionAndSortSpecsForWindowFunc(op.partitionSpec, op.orderSpec, op)) {
return None
}
Expand Down Expand Up @@ -141,12 +146,27 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
}
}.toArray

val (aggExpr, builtinFunc) = if (aggregateExpressions.nonEmpty) {
val (aggExpr, builtinFunc, ignoreNulls) = if (aggregateExpressions.nonEmpty) {
val modes = aggregateExpressions.map(_.mode).distinct
assert(modes.size == 1 && modes.head == Complete)
(aggExprToProto(aggregateExpressions.head, output, true, conf), None)
(aggExprToProto(aggregateExpressions.head, output, true, conf), None, false)
} else {
(None, exprToProto(windowExpr.windowFunction, output))
windowExpr.windowFunction match {
case lag: Lag =>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should Lead also be handled the same way?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Since you asked, I just added Lead in this PR too.

val inputExpr = exprToProto(lag.input, output)
val offsetExpr = exprToProto(lag.inputOffset, output)
val defaultExpr = exprToProto(lag.default, output)
val func = scalarFunctionExprToProto("lag", inputExpr, offsetExpr, defaultExpr)
(None, func, lag.ignoreNulls)
case lead: Lead =>
val inputExpr = exprToProto(lead.input, output)
val offsetExpr = exprToProto(lead.offset, output)
val defaultExpr = exprToProto(lead.default, output)
val func = scalarFunctionExprToProto("lead", inputExpr, offsetExpr, defaultExpr)
(None, func, lead.ignoreNulls)
case _ =>
(None, exprToProto(windowExpr.windowFunction, output), false)
}
}

if (aggExpr.isEmpty && builtinFunc.isEmpty) {
Expand Down Expand Up @@ -254,6 +274,7 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
.newBuilder()
.setBuiltInWindowFunction(builtinFunc.get)
.setSpec(spec)
.setIgnoreNulls(ignoreNulls)
.build())
} else if (aggExpr.isDefined) {
Some(
Expand Down
185 changes: 115 additions & 70 deletions spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, Row}
import org.apache.spark.sql.comet.CometWindowExec
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, lead, sum}
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -605,87 +606,131 @@ class CometWindowExecSuite extends CometTestBase {
}
}

// TODO: LAG produces incorrect results
ignore("window: LAG with default offset") {
withTempDir { dir =>
(0 until 30)
.map(i => (i % 3, i % 5, i))
.toDF("a", "b", "c")
.repartition(3)
.write
.mode("overwrite")
.parquet(dir.toString)

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LAG(c) OVER (PARTITION BY a ORDER BY b) as lag_c
FROM window_test
""")
checkSparkAnswerAndOperator(df)
test("window: LAG with default offset") {
withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") {
withTempDir { dir =>
(0 until 30)
.map(i => (i % 3, i % 5, i))
.toDF("a", "b", "c")
.repartition(3)
.write
.mode("overwrite")
.parquet(dir.toString)

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LAG(c) OVER (PARTITION BY a ORDER BY b, c) as lag_c
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this testing IGNORE NULLS?

Copy link
Copy Markdown
Member Author

@viirya viirya Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, missed it. Added another test with IGNORE NULLS now.

FROM window_test
""")
checkSparkAnswerAndOperator(df)
}
}
}

// TODO: LAG with offset 2 produces incorrect results
ignore("window: LAG with offset 2 and default value") {
withTempDir { dir =>
(0 until 30)
.map(i => (i % 3, i % 5, i))
.toDF("a", "b", "c")
.repartition(3)
.write
.mode("overwrite")
.parquet(dir.toString)

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LAG(c, 2, -1) OVER (PARTITION BY a ORDER BY b) as lag_c_2
FROM window_test
""")
checkSparkAnswerAndOperator(df)
test("window: LAG with offset 2 and default value") {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh nice, the PR also fixes this test although it is not related to IGNORE NULLS

withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") {
withTempDir { dir =>
(0 until 30)
.map(i => (i % 3, i % 5, i))
.toDF("a", "b", "c")
.repartition(3)
.write
.mode("overwrite")
.parquet(dir.toString)

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LAG(c, 2, -1) OVER (PARTITION BY a ORDER BY b, c) as lag_c_2
FROM window_test
""")
checkSparkAnswerAndOperator(df)
}
}
}

// TODO: LEAD produces incorrect results
ignore("window: LEAD with default offset") {
withTempDir { dir =>
(0 until 30)
.map(i => (i % 3, i % 5, i))
.toDF("a", "b", "c")
.repartition(3)
.write
.mode("overwrite")
.parquet(dir.toString)
test("window: LAG with IGNORE NULLS") {
withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") {
withTempDir { dir =>
Seq((1, 1, Some(10)), (1, 2, None), (1, 3, Some(30)), (2, 1, None), (2, 2, Some(20)))
.toDF("a", "b", "c")
.write
.mode("overwrite")
.parquet(dir.toString)

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LAG(c) IGNORE NULLS OVER (PARTITION BY a ORDER BY b) as lag_c
FROM window_test
""")
checkSparkAnswerAndOperator(df)
}
}
}

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LEAD(c) OVER (PARTITION BY a ORDER BY b) as lead_c
FROM window_test
""")
checkSparkAnswerAndOperator(df)
test("window: LEAD with default offset") {
withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") {
withTempDir { dir =>
(0 until 30)
.map(i => (i % 3, i % 5, i))
.toDF("a", "b", "c")
.repartition(3)
.write
.mode("overwrite")
.parquet(dir.toString)

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LEAD(c) OVER (PARTITION BY a ORDER BY b, c) as lead_c
FROM window_test
""")
checkSparkAnswerAndOperator(df)
}
}
}

// TODO: LEAD with offset 2 produces incorrect results
ignore("window: LEAD with offset 2 and default value") {
withTempDir { dir =>
(0 until 30)
.map(i => (i % 3, i % 5, i))
.toDF("a", "b", "c")
.repartition(3)
.write
.mode("overwrite")
.parquet(dir.toString)
test("window: LEAD with offset 2 and default value") {
withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") {
withTempDir { dir =>
(0 until 30)
.map(i => (i % 3, i % 5, i))
.toDF("a", "b", "c")
.repartition(3)
.write
.mode("overwrite")
.parquet(dir.toString)

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LEAD(c, 2, -1) OVER (PARTITION BY a ORDER BY b, c) as lead_c_2
FROM window_test
""")
checkSparkAnswerAndOperator(df)
}
}
}

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LEAD(c, 2, -1) OVER (PARTITION BY a ORDER BY b) as lead_c_2
FROM window_test
""")
checkSparkAnswerAndOperator(df)
test("window: LEAD with IGNORE NULLS") {
withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") {
withTempDir { dir =>
Seq((1, 1, Some(10)), (1, 2, None), (1, 3, Some(30)), (2, 1, None), (2, 2, Some(20)))
.toDF("a", "b", "c")
.write
.mode("overwrite")
.parquet(dir.toString)

spark.read.parquet(dir.toString).createOrReplaceTempView("window_test")
val df = sql("""
SELECT a, b, c,
LEAD(c) IGNORE NULLS OVER (PARTITION BY a ORDER BY b) as lead_c
FROM window_test
""")
checkSparkAnswerAndOperator(df)
}
}
}

Expand Down
Loading