Skip to content

Commit 919c9ad

Browse files
authored
perf: Optimize some decimal expressions (#3619)
* feat: fused WideDecimalBinaryExpr for Decimal128 add/sub/mul

  Replace the 4-node expression tree (Cast→BinaryExpr→Cast→Cast) used for Decimal128 arithmetic that may overflow with a single fused expression that performs i256 register arithmetic directly. This reduces per-batch allocation from 4 intermediate arrays (112 bytes/elem) to 1 output array (16 bytes/elem). The new WideDecimalBinaryExpr evaluates children, performs add/sub/mul using i256 intermediates via try_binary, applies scale adjustment with HALF_UP rounding, checks precision bounds, and outputs a single Decimal128 array. Follows the same pattern as decimal_div.

* feat: add criterion benchmark for wide decimal binary expr

  Add benchmark comparing old Cast->BinaryExpr->Cast chain vs fused WideDecimalBinaryExpr for Decimal128 add/sub/mul. Covers four cases: add with same scale, add with different scales, multiply, and subtract.

* feat: fuse CheckOverflow with Cast and WideDecimalBinaryExpr

  Eliminate redundant CheckOverflow when wrapping WideDecimalBinaryExpr (which already handles overflow). Fuse Cast(Decimal128→Decimal128) + CheckOverflow into a single DecimalRescaleCheckOverflow expression that rescales and validates precision in one pass.
* fix: address PR review feedback for decimal optimizations

  - Handle scale-up when s_out > max(s1, s2) in add/subtract
  - Propagate errors in scalar path when fail_on_error=true
  - Guard against large scale delta (>38) overflow in rescale
  - Assert precision <= 38 in precision_bound
  - Assert exp <= 76 in i256_pow10
  - Remove unnecessary _ prefix on used variables in planner
  - Use value.signum() instead of manual sign check
  - Verify Cast target type matches before fusing with CheckOverflow
  - Validate children count in with_new_children for both expressions
  - Add tests for scale-up, scalar error propagation, and large delta

* style: apply cargo fmt

* fix: add defensive checks for CheckOverflow bypass and multiply scale-up

  - Validate WideDecimalBinaryExpr output type matches CheckOverflow data_type before bypassing the overflow check
  - Handle s_out > natural_scale (scale-up) in multiply path for consistency with add/subtract
1 parent a05c568 commit 919c9ad

8 files changed

Lines changed: 1273 additions & 27 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,9 @@ use datafusion_comet_proto::{
126126
use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
127127
use datafusion_comet_spark_expr::{
128128
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
129-
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr,
130-
RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
129+
DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
130+
NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson,
131+
UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
131132
};
132133
use itertools::Itertools;
133134
use jni::objects::GlobalRef;
@@ -408,10 +409,45 @@ impl PhysicalPlanner {
408409
)))
409410
}
410411
ExprStruct::CheckOverflow(expr) => {
411-
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
412+
let child =
413+
self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
412414
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
413415
let fail_on_error = expr.fail_on_error;
414416

417+
// WideDecimalBinaryExpr already handles overflow — skip redundant check
418+
// but only if its output type matches CheckOverflow's declared type
419+
if child
420+
.as_any()
421+
.downcast_ref::<WideDecimalBinaryExpr>()
422+
.is_some()
423+
{
424+
let child_type = child.data_type(&input_schema)?;
425+
if child_type == data_type {
426+
return Ok(child);
427+
}
428+
}
429+
430+
// Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check
431+
// Only fuse when the Cast target type matches the CheckOverflow output type
432+
if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
433+
if let (
434+
DataType::Decimal128(p_out, s_out),
435+
Ok(DataType::Decimal128(_p_in, s_in)),
436+
) = (&data_type, cast.child.data_type(&input_schema))
437+
{
438+
let cast_target = cast.data_type(&input_schema)?;
439+
if cast_target == data_type {
440+
return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
441+
Arc::clone(&cast.child),
442+
s_in,
443+
*p_out,
444+
*s_out,
445+
fail_on_error,
446+
)));
447+
}
448+
}
449+
}
450+
415451
// Look up query context from registry if expr_id is present
416452
let query_context = spark_expr.expr_id.and_then(|expr_id| {
417453
let registry = &self.query_context_registry;
@@ -740,29 +776,22 @@ impl PhysicalPlanner {
740776
|| (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
741777
{
742778
let data_type = return_type.map(to_arrow_datatype).unwrap();
743-
// For some Decimal128 operations, we need wider internal digits.
744-
// Cast left and right to Decimal256 and cast the result back to Decimal128
745-
let left = Arc::new(Cast::new(
746-
left,
747-
DataType::Decimal256(p1, s1),
748-
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
749-
None,
750-
None,
751-
));
752-
let right = Arc::new(Cast::new(
753-
right,
754-
DataType::Decimal256(p2, s2),
755-
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
756-
None,
757-
None,
758-
));
759-
let child = Arc::new(BinaryExpr::new(left, op, right));
760-
Ok(Arc::new(Cast::new(
761-
child,
762-
data_type,
763-
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
764-
None,
765-
None,
779+
let (p_out, s_out) = match &data_type {
780+
DataType::Decimal128(p, s) => (*p, *s),
781+
dt => {
782+
return Err(ExecutionError::GeneralError(format!(
783+
"Expected Decimal128 return type, got {dt:?}"
784+
)))
785+
}
786+
};
787+
let wide_op = match op {
788+
DataFusionOperator::Plus => WideDecimalOp::Add,
789+
DataFusionOperator::Minus => WideDecimalOp::Subtract,
790+
DataFusionOperator::Multiply => WideDecimalOp::Multiply,
791+
_ => unreachable!(),
792+
};
793+
Ok(Arc::new(WideDecimalBinaryExpr::new(
794+
left, right, wide_op, p_out, s_out, eval_mode,
766795
)))
767796
}
768797
(

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ path = "tests/spark_expr_reg.rs"
105105
name = "cast_from_boolean"
106106
harness = false
107107

108+
[[bench]]
109+
name = "wide_decimal"
110+
harness = false
111+
108112
[[bench]]
109113
name = "cast_non_int_numeric_timestamp"
110114
harness = false
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused WideDecimalBinaryExpr
19+
//! for Decimal128 arithmetic that requires wider intermediate precision.
20+
21+
use arrow::array::builder::Decimal128Builder;
22+
use arrow::array::RecordBatch;
23+
use arrow::datatypes::{DataType, Field, Schema};
24+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
25+
use datafusion::logical_expr::Operator;
26+
use datafusion::physical_expr::expressions::{BinaryExpr, Column};
27+
use datafusion::physical_expr::PhysicalExpr;
28+
use datafusion_comet_spark_expr::{
29+
Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp,
30+
};
31+
use std::sync::Arc;
32+
33+
const BATCH_SIZE: usize = 8192;
34+
35+
/// Build a RecordBatch with two Decimal128 columns.
36+
fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch {
37+
let mut left = Decimal128Builder::new();
38+
let mut right = Decimal128Builder::new();
39+
for i in 0..BATCH_SIZE as i128 {
40+
left.append_value(123456789012345_i128 + i * 1000);
41+
right.append_value(987654321098765_i128 - i * 1000);
42+
}
43+
let left = left.finish().with_data_type(DataType::Decimal128(p1, s1));
44+
let right = right.finish().with_data_type(DataType::Decimal128(p2, s2));
45+
let schema = Schema::new(vec![
46+
Field::new("left", DataType::Decimal128(p1, s1), false),
47+
Field::new("right", DataType::Decimal128(p2, s2), false),
48+
]);
49+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left), Arc::new(right)]).unwrap()
50+
}
51+
52+
/// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr, Cast(Decimal256->Decimal128).
53+
fn build_old_expr(
54+
p1: u8,
55+
s1: i8,
56+
p2: u8,
57+
s2: i8,
58+
op: Operator,
59+
out_type: DataType,
60+
) -> Arc<dyn PhysicalExpr> {
61+
let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
62+
let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
63+
let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
64+
let left_cast = Arc::new(Cast::new(
65+
left_col,
66+
DataType::Decimal256(p1, s1),
67+
cast_opts.clone(),
68+
None,
69+
None,
70+
));
71+
let right_cast = Arc::new(Cast::new(
72+
right_col,
73+
DataType::Decimal256(p2, s2),
74+
cast_opts.clone(),
75+
None,
76+
None,
77+
));
78+
let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast));
79+
Arc::new(Cast::new(binary, out_type, cast_opts, None, None))
80+
}
81+
82+
/// New approach: single fused WideDecimalBinaryExpr.
83+
fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn PhysicalExpr> {
84+
let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
85+
let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
86+
Arc::new(WideDecimalBinaryExpr::new(
87+
left_col,
88+
right_col,
89+
op,
90+
p_out,
91+
s_out,
92+
EvalMode::Legacy,
93+
))
94+
}
95+
96+
fn bench_case(
97+
group: &mut criterion::BenchmarkGroup<criterion::measurement::WallTime>,
98+
name: &str,
99+
batch: &RecordBatch,
100+
old_expr: &Arc<dyn PhysicalExpr>,
101+
new_expr: &Arc<dyn PhysicalExpr>,
102+
) {
103+
group.bench_with_input(BenchmarkId::new("old", name), batch, |b, batch| {
104+
b.iter(|| old_expr.evaluate(batch).unwrap());
105+
});
106+
group.bench_with_input(BenchmarkId::new("fused", name), batch, |b, batch| {
107+
b.iter(|| new_expr.evaluate(batch).unwrap());
108+
});
109+
}
110+
111+
fn criterion_benchmark(c: &mut Criterion) {
112+
let mut group = c.benchmark_group("wide_decimal");
113+
114+
// Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) -> Decimal128(38,10)
115+
// Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 38 >= 38
116+
{
117+
let batch = make_decimal_batch(38, 10, 38, 10);
118+
let old = build_old_expr(38, 10, 38, 10, Operator::Plus, DataType::Decimal128(38, 10));
119+
let new = build_new_expr(WideDecimalOp::Add, 38, 10);
120+
bench_case(&mut group, "add_same_scale", &batch, &old, &new);
121+
}
122+
123+
// Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4) -> Decimal128(38,6)
124+
{
125+
let batch = make_decimal_batch(38, 6, 38, 4);
126+
let old = build_old_expr(38, 6, 38, 4, Operator::Plus, DataType::Decimal128(38, 6));
127+
let new = build_new_expr(WideDecimalOp::Add, 38, 6);
128+
bench_case(&mut group, "add_diff_scale", &batch, &old, &new);
129+
}
130+
131+
// Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) -> Decimal128(38,6)
132+
// Triggers wide path because p1 + p2 = 40 >= 38
133+
{
134+
let batch = make_decimal_batch(20, 10, 20, 10);
135+
let old = build_old_expr(
136+
20,
137+
10,
138+
20,
139+
10,
140+
Operator::Multiply,
141+
DataType::Decimal128(38, 6),
142+
);
143+
let new = build_new_expr(WideDecimalOp::Multiply, 38, 6);
144+
bench_case(&mut group, "multiply", &batch, &old, &new);
145+
}
146+
147+
// Case 4: Subtract with same scale - Decimal128(38,18) - Decimal128(38,18) -> Decimal128(38,18)
148+
{
149+
let batch = make_decimal_batch(38, 18, 38, 18);
150+
let old = build_old_expr(
151+
38,
152+
18,
153+
38,
154+
18,
155+
Operator::Minus,
156+
DataType::Decimal128(38, 18),
157+
);
158+
let new = build_new_expr(WideDecimalOp::Subtract, 38, 18);
159+
bench_case(&mut group, "subtract", &batch, &old, &new);
160+
}
161+
162+
group.finish();
163+
}
164+
165+
criterion_group!(benches, criterion_benchmark);
166+
criterion_main!(benches);

native/spark-expr/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ pub use json_funcs::{FromJson, ToJson};
8080
pub use math_funcs::{
8181
create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
8282
spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
83-
spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
83+
spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
84+
NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
8485
};
8586
pub use query_context::{create_query_context_map, QueryContext, QueryContextMap};
8687
pub use string_funcs::*;

0 commit comments

Comments (0)