Skip to content

Commit c3dd3a4

Browse files
authored
fix: handle scalar decimal value overflow correctly in ANSI mode (#3803)
* fix: handle scalar decimal value overflow correctly.
1 parent 124646f commit c3dd3a4

2 files changed

Lines changed: 196 additions & 16 deletions

File tree

native/spark-expr/src/math_funcs/internal/checkoverflow.rs

Lines changed: 167 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -199,22 +199,38 @@ impl PhysicalExpr for CheckOverflow {
199199
Ok(ColumnarValue::Array(new_array))
200200
}
201201
ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
202-
// `fail_on_error` is only true when ANSI is enabled, which we don't support yet
203-
// (Java side will simply fallback to Spark when it is enabled)
204-
assert!(
205-
!self.fail_on_error,
206-
"fail_on_error (ANSI mode) is not supported yet"
207-
);
208-
209-
let new_v: Option<i128> = v.and_then(|v| {
210-
Decimal128Type::validate_decimal_precision(v, precision, scale)
211-
.map(|_| v)
212-
.ok()
213-
});
214-
215-
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
216-
new_v, precision, scale,
217-
)))
202+
if self.fail_on_error {
203+
if let Some(val) = v {
204+
Decimal128Type::validate_decimal_precision(val, precision, scale).map_err(
205+
|_| {
206+
let spark_error =
207+
crate::error::decimal_overflow_error(val, precision, scale);
208+
if let Some(ctx) = &self.query_context {
209+
DataFusionError::External(Box::new(
210+
crate::SparkErrorWithContext::with_context(
211+
spark_error,
212+
Arc::clone(ctx),
213+
),
214+
))
215+
} else {
216+
DataFusionError::External(Box::new(spark_error))
217+
}
218+
},
219+
)?;
220+
}
221+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
222+
v, precision, scale,
223+
)))
224+
} else {
225+
let new_v: Option<i128> = v.and_then(|v| {
226+
Decimal128Type::validate_decimal_precision(v, precision, scale)
227+
.map(|_| v)
228+
.ok()
229+
});
230+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
231+
new_v, precision, scale,
232+
)))
233+
}
218234
}
219235
v => Err(DataFusionError::Execution(format!(
220236
"CheckOverflow's child expression should be decimal array, but found {v:?}"
@@ -239,3 +255,138 @@ impl PhysicalExpr for CheckOverflow {
239255
)))
240256
}
241257
}
258+
259+
#[cfg(test)]
260+
mod tests {
261+
use super::*;
262+
use arrow::datatypes::{Field, Schema};
263+
use arrow::record_batch::RecordBatch;
264+
use std::fmt::{Display, Formatter};
265+
266+
/// Helper that always returns a fixed Decimal128 scalar.
267+
#[derive(Debug, Eq, PartialEq, Hash)]
268+
struct ScalarChild(Option<i128>, u8, i8);
269+
270+
impl Display for ScalarChild {
271+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
272+
write!(f, "ScalarChild({:?})", self.0)
273+
}
274+
}
275+
276+
impl PhysicalExpr for ScalarChild {
277+
fn as_any(&self) -> &dyn Any {
278+
self
279+
}
280+
fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
281+
Ok(DataType::Decimal128(self.1, self.2))
282+
}
283+
fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
284+
Ok(true)
285+
}
286+
fn evaluate(&self, _: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
287+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
288+
self.0, self.1, self.2,
289+
)))
290+
}
291+
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
292+
vec![]
293+
}
294+
fn with_new_children(
295+
self: Arc<Self>,
296+
_: Vec<Arc<dyn PhysicalExpr>>,
297+
) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
298+
Ok(self)
299+
}
300+
fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
301+
Display::fmt(self, f)
302+
}
303+
}
304+
305+
fn empty_batch() -> RecordBatch {
306+
let schema = Schema::new(vec![Field::new("x", DataType::Decimal128(38, 0), true)]);
307+
RecordBatch::new_empty(Arc::new(schema))
308+
}
309+
310+
fn make_check_overflow(
311+
value: Option<i128>,
312+
precision: u8,
313+
scale: i8,
314+
fail_on_error: bool,
315+
) -> CheckOverflow {
316+
CheckOverflow::new(
317+
Arc::new(ScalarChild(value, precision, scale)),
318+
DataType::Decimal128(precision, scale),
319+
fail_on_error,
320+
None,
321+
None,
322+
)
323+
}
324+
325+
// --- scalar, fail_on_error = false (legacy mode) ---
326+
327+
#[test]
328+
fn test_scalar_no_overflow_legacy() {
329+
// 999 fits in precision 3, scale 0 → returned as-is
330+
let expr = make_check_overflow(Some(999), 3, 0, false);
331+
let result = expr.evaluate(&empty_batch()).unwrap();
332+
match result {
333+
ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, Some(999)),
334+
other => panic!("unexpected: {other:?}"),
335+
}
336+
}
337+
338+
#[test]
339+
fn test_scalar_overflow_returns_null_in_legacy_mode() {
340+
// 1000 does not fit in precision 3 → null, no error
341+
let expr = make_check_overflow(Some(1000), 3, 0, false);
342+
let result = expr.evaluate(&empty_batch()).unwrap();
343+
match result {
344+
ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, None),
345+
other => panic!("unexpected: {other:?}"),
346+
}
347+
}
348+
349+
#[test]
350+
fn test_scalar_null_passthrough_legacy() {
351+
let expr = make_check_overflow(None, 3, 0, false);
352+
let result = expr.evaluate(&empty_batch()).unwrap();
353+
match result {
354+
ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, None),
355+
other => panic!("unexpected: {other:?}"),
356+
}
357+
}
358+
359+
// --- scalar, fail_on_error = true (ANSI mode) ---
360+
361+
#[test]
362+
fn test_scalar_no_overflow_ansi() {
363+
// 999 fits in precision 3 → returned as-is, no error
364+
let expr = make_check_overflow(Some(999), 3, 0, true);
365+
let result = expr.evaluate(&empty_batch()).unwrap();
366+
match result {
367+
ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, Some(999)),
368+
other => panic!("unexpected: {other:?}"),
369+
}
370+
}
371+
372+
#[test]
373+
fn test_scalar_overflow_returns_error_in_ansi_mode() {
374+
// 1000 does not fit in precision 3 → error, not Ok(None)
375+
// This is the case that previously panicked with "fail_on_error (ANSI mode) is not
376+
// supported yet".
377+
let expr = make_check_overflow(Some(1000), 3, 0, true);
378+
let result = expr.evaluate(&empty_batch());
379+
assert!(result.is_err(), "expected error on overflow in ANSI mode");
380+
}
381+
382+
#[test]
383+
fn test_scalar_null_passthrough_ansi() {
384+
// None input → None output even in ANSI mode (no value to overflow)
385+
let expr = make_check_overflow(None, 3, 0, true);
386+
let result = expr.evaluate(&empty_batch()).unwrap();
387+
match result {
388+
ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, None),
389+
other => panic!("unexpected: {other:?}"),
390+
}
391+
}
392+
}

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,35 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12711271
}
12721272
}
12731273

1274+
test("scalar decimal overflow - legacy mode produces null") {
1275+
// 1.1e19 * 1.1e19 = 1.21e38 fits in i128 (max ~1.7e38) but exceeds DECIMAL(38,0)'s
1276+
// max of 10^38-1, so CheckOverflow nulls the result in legacy (non-ANSI) mode.
1277+
withSQLConf(CometConf.COMET_ENABLED.key -> "true", SQLConf.ANSI_ENABLED.key -> "false") {
1278+
withParquetTable(Seq((BigDecimal("11000000000000000000"), 0)), "tbl") {
1279+
checkSparkAnswerAndOperator("SELECT _1 * _1 FROM tbl")
1280+
}
1281+
}
1282+
}
1283+
1284+
test("scalar decimal overflow - ANSI mode throws ArithmeticException") {
1285+
// 1.1e19 * 1.1e19 = 1.21e38 overflows DECIMAL(38,0). With ANSI mode on, both Spark and
1286+
// Comet must throw — Comet must not panic or silently return null. Spark reports
1287+
// NUMERIC_VALUE_OUT_OF_RANGE; Comet's WideDecimalBinaryExpr catches the overflow first
1288+
// and surfaces it as an arithmetic overflow error.
1289+
withSQLConf(CometConf.COMET_ENABLED.key -> "true", SQLConf.ANSI_ENABLED.key -> "true") {
1290+
withParquetTable(Seq((BigDecimal("11000000000000000000"), 0)), "tbl") {
1291+
val res = sql("SELECT _1 * _1 FROM tbl")
1292+
checkSparkAnswerMaybeThrows(res) match {
1293+
case (Some(sparkExc), Some(cometExc)) =>
1294+
assert(sparkExc.getMessage.contains("NUMERIC_VALUE_OUT_OF_RANGE"))
1295+
assert(cometExc.getMessage.toLowerCase.contains("overflow"))
1296+
case _ =>
1297+
fail("Expected exception for decimal overflow in ANSI mode")
1298+
}
1299+
}
1300+
}
1301+
}
1302+
12741303
test("cast decimals to int") {
12751304
Seq(16, 1024).foreach { batchSize =>
12761305
withSQLConf(

0 commit comments

Comments
 (0)