Skip to content

Commit c481721

Browse files
chor: enable Corr
1 parent f5fa616 commit c481721

5 files changed

Lines changed: 60 additions & 37 deletions

File tree

docs/source/user-guide/latest/compatibility.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on
8787

8888
### Aggregate Expressions
8989

90-
- **Corr**: Returns null instead of NaN in some edge cases.
91-
[#2646](https://github.com/apache/datafusion-comet/issues/2646)
9290
- **First, Last**: These functions are not deterministic. When `ignoreNulls` is set, results may not match Spark.
9391
[#1630](https://github.com/apache/datafusion-comet/issues/1630)
9492

docs/source/user-guide/latest/expressions.md

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -189,27 +189,27 @@ Expressions that are not Spark-compatible will fall back to Spark by default and
189189

190190
## Aggregate Expressions
191191

192-
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
193-
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------------------------------------------------------- |
194-
| Average | | Yes, except for ANSI mode | |
195-
| BitAndAgg | | Yes | |
196-
| BitOrAgg | | Yes | |
197-
| BitXorAgg | | Yes | |
198-
| BoolAnd | `bool_and` | Yes | |
199-
| BoolOr | `bool_or` | Yes | |
200-
| Corr | | No | Returns null instead of NaN in some edge cases ([#2646](https://github.com/apache/datafusion-comet/issues/2646)) |
201-
| Count | | Yes | |
202-
| CovPopulation | | Yes | |
203-
| CovSample | | Yes | |
204-
| First | | No | This function is not deterministic. Results may not match Spark. |
205-
| Last | | No | This function is not deterministic. Results may not match Spark. |
206-
| Max | | Yes | |
207-
| Min | | Yes | |
208-
| StddevPop | | Yes | |
209-
| StddevSamp | | Yes | |
210-
| Sum | | Yes, except for ANSI mode | |
211-
| VariancePop | | Yes | |
212-
| VarianceSamp | | Yes | |
192+
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
193+
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------- |
194+
| Average | | Yes, except for ANSI mode | |
195+
| BitAndAgg | | Yes | |
196+
| BitOrAgg | | Yes | |
197+
| BitXorAgg | | Yes | |
198+
| BoolAnd | `bool_and` | Yes | |
199+
| BoolOr | `bool_or` | Yes | |
200+
| Corr | | Yes | |
201+
| Count | | Yes | |
202+
| CovPopulation | | Yes | |
203+
| CovSample | | Yes | |
204+
| First | | No | This function is not deterministic. Results may not match Spark. |
205+
| Last | | No | This function is not deterministic. Results may not match Spark. |
206+
| Max | | Yes | |
207+
| Min | | Yes | |
208+
| StddevPop | | Yes | |
209+
| StddevSamp | | Yes | |
210+
| Sum | | Yes, except for ANSI mode | |
211+
| VariancePop | | Yes | |
212+
| VarianceSamp | | Yes | |
213213

214214
## Window Functions
215215

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import scala.jdk.CollectionConverters._
2323

24-
import org.apache.spark.sql.catalyst.expressions.Attribute
24+
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, If, IsNaN, Literal}
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
2626
import org.apache.spark.sql.internal.SQLConf
2727
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
@@ -584,20 +584,18 @@ object CometStddevPop extends CometAggregateExpressionSerde[StddevPop] with Come
584584
}
585585

586586
object CometCorr extends CometAggregateExpressionSerde[Corr] {
587-
588-
override def getSupportLevel(expr: Corr): SupportLevel =
589-
Incompatible(
590-
Some(
591-
"Returns null instead of NaN in some edge cases" +
592-
" (https://github.com/apache/datafusion-comet/issues/2646)"))
593-
594587
override def convert(
595588
aggExpr: AggregateExpression,
596589
corr: Corr,
597590
inputs: Seq[Attribute],
598591
binding: Boolean,
599592
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
600-
val child1Expr = exprToProto(corr.x, inputs, binding)
593+
// When both inputs are NaN, convert one input to null in order to return null.
594+
// This matches Spark's behavior where corr(NaN, NaN) returns null.
595+
val wrappedX =
596+
If(And(IsNaN(corr.x), IsNaN(corr.y)), Literal.create(null, corr.x.dataType), corr.x)
597+
598+
val child1Expr = exprToProto(wrappedX, inputs, binding)
601599
val child2Expr = exprToProto(corr.y, inputs, binding)
602600
val dataType = serializeDataType(corr.dataType)
603601

spark/src/test/resources/sql-tests/expressions/aggregate/corr.sql

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
-- specific language governing permissions and limitations
1616
-- under the License.
1717

18-
-- Config: spark.comet.expression.Corr.allowIncompatible=true
1918
-- ConfigMatrix: parquet.enable.dictionary=false,true
2019

2120
statement
@@ -29,3 +28,13 @@ SELECT corr(x, y) FROM test_corr
2928

3029
query tolerance=1e-6
3130
SELECT grp, corr(x, y) FROM test_corr GROUP BY grp ORDER BY grp
31+
32+
-- Test permutations of NULL and NaN
33+
statement
34+
CREATE TABLE test_corr_nan(x double, y double, grp string) USING parquet
35+
36+
statement
37+
INSERT INTO test_corr_nan VALUES (cast('NaN' as double), cast('NaN' as double), 'both_nan'), (cast('NaN' as double), 1.0, 'nan_val'), (1.0, cast('NaN' as double), 'val_nan'), (NULL, cast('NaN' as double), 'null_nan'), (cast('NaN' as double), NULL, 'nan_null'), (NULL, NULL, 'both_null'), (NULL, 1.0, 'null_val'), (1.0, NULL, 'val_null')
38+
39+
query tolerance=1e-6
40+
SELECT grp, corr(x, y) FROM test_corr_nan GROUP BY grp ORDER BY grp

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import scala.util.Random
2424
import org.apache.hadoop.fs.Path
2525
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
2626
import org.apache.spark.sql.catalyst.expressions.Cast
27-
import org.apache.spark.sql.catalyst.expressions.aggregate.Corr
2827
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
2928
import org.apache.spark.sql.comet.CometHashAggregateExec
3029
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1320,9 +1319,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
13201319
}
13211320

13221321
test("covariance & correlation") {
1323-
withSQLConf(
1324-
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
1325-
CometConf.getExprAllowIncompatConfigKey(classOf[Corr]) -> "true") {
1322+
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
13261323
Seq("jvm", "native").foreach { cometShuffleMode =>
13271324
withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
13281325
Seq(true, false).foreach { dictionary =>
@@ -1393,6 +1390,27 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
13931390
}
13941391
}
13951392

1393+
test("corr - nan/null") {
1394+
withTable("t1") {
1395+
sql("""create table t1 using parquet as
1396+
select cast(null as float) f1, CAST('NaN' AS float) f2, cast(null as double) d1, CAST('NaN' AS double) d2
1397+
from range(1)
1398+
""")
1399+
1400+
checkSparkAnswerAndOperator("""
1401+
|select
1402+
| corr(f1, f2) c1,
1403+
| corr(f1, f1) c2,
1404+
| corr(f2, f1) c3,
1405+
| corr(f2, f2) c4,
1406+
| corr(d1, d2) c5,
1407+
| corr(d1, d1) c6,
1408+
| corr(d2, d1) c7,
1409+
| corr(d2, d2) c8
1410+
| FROM t1""".stripMargin)
1411+
}
1412+
}
1413+
13961414
test("var_pop and var_samp") {
13971415
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
13981416
Seq("native", "jvm").foreach { cometShuffleMode =>

0 commit comments

Comments
 (0)