Skip to content

Commit 4d2c398

Browse files
authored
fix: fix string to timestamp cast for UTC timestamps (#3656)
* fix: fix string to timestamp cast for UTC timestamps
1 parent 8867501 commit 4d2c398

8 files changed

Lines changed: 264 additions & 101 deletions

File tree

native/common/src/error.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ pub enum SparkError {
3232
to_type: String,
3333
},
3434

35+
/// Like CastInvalidValue but maps to SparkDateTimeException instead of SparkNumberFormatException.
36+
/// Used for string → timestamp/date cast failures.
37+
#[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
38+
because it is malformed. Correct the value as per the syntax, or change its target type. \
39+
Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
40+
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
41+
InvalidInputInCastToDatetime {
42+
value: String,
43+
from_type: String,
44+
to_type: String,
45+
},
46+
3547
#[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
3648
NumericValueOutOfRange {
3749
value: String,
@@ -208,6 +220,7 @@ impl SparkError {
208220
pub(crate) fn error_type_name(&self) -> &'static str {
209221
match self {
210222
SparkError::CastInvalidValue { .. } => "CastInvalidValue",
223+
SparkError::InvalidInputInCastToDatetime { .. } => "InvalidInputInCastToDatetime",
211224
SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange",
212225
SparkError::NumericOutOfRange { .. } => "NumericOutOfRange",
213226
SparkError::CastOverFlow { .. } => "CastOverFlow",
@@ -266,6 +279,17 @@ impl SparkError {
266279
"toType": to_type,
267280
})
268281
}
282+
SparkError::InvalidInputInCastToDatetime {
283+
value,
284+
from_type,
285+
to_type,
286+
} => {
287+
serde_json::json!({
288+
"value": value,
289+
"fromType": from_type,
290+
"toType": to_type,
291+
})
292+
}
269293
SparkError::NumericValueOutOfRange {
270294
value,
271295
precision,
@@ -505,7 +529,8 @@ impl SparkError {
505529
| SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException",
506530

507531
// DateTimeException
508-
SparkError::CannotParseTimestamp { .. }
532+
SparkError::InvalidInputInCastToDatetime { .. }
533+
| SparkError::CannotParseTimestamp { .. }
509534
| SparkError::InvalidFractionOfSecond { .. } => "org/apache/spark/SparkDateTimeException",
510535

511536
// IllegalArgumentException
@@ -530,6 +555,7 @@ impl SparkError {
530555
match self {
531556
// Cast errors
532557
SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"),
558+
SparkError::InvalidInputInCastToDatetime { .. } => Some("CAST_INVALID_INPUT"),
533559
SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"),
534560
SparkError::NumericValueOutOfRange { .. } => {
535561
Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION")

native/spark-expr/src/conversion_funcs/string.rs

Lines changed: 101 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,35 @@ use num::{CheckedSub, Integer};
3131
use regex::Regex;
3232
use std::num::Wrapping;
3333
use std::str::FromStr;
34-
use std::sync::Arc;
34+
use std::sync::{Arc, LazyLock};
3535

3636
macro_rules! cast_utf8_to_timestamp {
3737
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
3838
let len = $array.len();
3939
let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
40+
let mut cast_err: Option<SparkError> = None;
4041
for i in 0..len {
4142
if $array.is_null(i) {
4243
cast_array.append_null()
43-
} else if let Ok(Some(cast_value)) =
44-
$cast_method($array.value(i).trim(), $eval_mode, $tz)
45-
{
46-
cast_array.append_value(cast_value);
4744
} else {
48-
cast_array.append_null()
45+
match $cast_method($array.value(i).trim(), $eval_mode, $tz) {
46+
Ok(Some(cast_value)) => cast_array.append_value(cast_value),
47+
Ok(None) => cast_array.append_null(),
48+
Err(e) => {
49+
if $eval_mode == EvalMode::Ansi {
50+
cast_err = Some(e);
51+
break;
52+
}
53+
cast_array.append_null()
54+
}
55+
}
4956
}
5057
}
51-
let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
52-
result
58+
if let Some(e) = cast_err {
59+
Err(e)
60+
} else {
61+
Ok(Arc::new(cast_array.finish()) as ArrayRef)
62+
}
5363
}};
5464
}
5565

@@ -668,15 +678,13 @@ pub(crate) fn cast_string_to_timestamp(
668678
let tz = &timezone::Tz::from_str(timezone_str).unwrap();
669679

670680
let cast_array: ArrayRef = match to_type {
671-
DataType::Timestamp(_, _) => {
672-
cast_utf8_to_timestamp!(
673-
string_array,
674-
eval_mode,
675-
TimestampMicrosecondType,
676-
timestamp_parser,
677-
tz
678-
)
679-
}
681+
DataType::Timestamp(_, _) => cast_utf8_to_timestamp!(
682+
string_array,
683+
eval_mode,
684+
TimestampMicrosecondType,
685+
timestamp_parser,
686+
tz
687+
)?,
680688
_ => unreachable!("Invalid data type {:?} in cast from string", to_type),
681689
};
682690
Ok(cast_array)
@@ -961,6 +969,12 @@ fn get_timestamp_values<T: TimeZone>(
961969
) -> SparkResult<Option<i64>> {
962970
let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
963971
let year = values[0].parse::<i32>().unwrap_or_default();
972+
973+
// NaiveDate (used internally by chrono's with_ymd_and_hms) is bounded to years -262143..=262142.
974+
if !(-262143..=262142).contains(&year) {
975+
return Ok(None);
976+
}
977+
964978
let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
965979
let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
966980
let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
@@ -1004,7 +1018,7 @@ fn get_timestamp_values<T: TimeZone>(
10041018
.with_second(second)
10051019
.with_microsecond(microsecond),
10061020
_ => {
1007-
return Err(SparkError::CastInvalidValue {
1021+
return Err(SparkError::InvalidInputInCastToDatetime {
10081022
value: value.to_string(),
10091023
from_type: "STRING".to_string(),
10101024
to_type: "TIMESTAMP".to_string(),
@@ -1082,7 +1096,21 @@ fn parse_str_to_microsecond_timestamp<T: TimeZone>(
10821096
get_timestamp_values(value, "microsecond", tz)
10831097
}
10841098

1085-
// used in tests only
1099+
type TimestampPattern<T> = (&'static Regex, fn(&str, &T) -> SparkResult<Option<i64>>);
1100+
1101+
static RE_YEAR: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^\d{4,7}$").unwrap());
1102+
static RE_MONTH: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^\d{4,7}-\d{2}$").unwrap());
1103+
static RE_DAY: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^\d{4,7}-\d{2}-\d{2}$").unwrap());
1104+
static RE_HOUR: LazyLock<Regex> =
1105+
LazyLock::new(|| Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{1,2}$").unwrap());
1106+
static RE_MINUTE: LazyLock<Regex> =
1107+
LazyLock::new(|| Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap());
1108+
static RE_SECOND: LazyLock<Regex> =
1109+
LazyLock::new(|| Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap());
1110+
static RE_MICROSECOND: LazyLock<Regex> =
1111+
LazyLock::new(|| Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap());
1112+
static RE_TIME_ONLY: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^T\d{1,2}$").unwrap());
1113+
10861114
fn timestamp_parser<T: TimeZone>(
10871115
value: &str,
10881116
eval_mode: EvalMode,
@@ -1092,40 +1120,15 @@ fn timestamp_parser<T: TimeZone>(
10921120
if value.is_empty() {
10931121
return Ok(None);
10941122
}
1095-
// Define regex patterns and corresponding parsing functions
1096-
let patterns = &[
1097-
(
1098-
Regex::new(r"^\d{4,5}$").unwrap(),
1099-
parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
1100-
),
1101-
(
1102-
Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
1103-
parse_str_to_month_timestamp,
1104-
),
1105-
(
1106-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
1107-
parse_str_to_day_timestamp,
1108-
),
1109-
(
1110-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
1111-
parse_str_to_hour_timestamp,
1112-
),
1113-
(
1114-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
1115-
parse_str_to_minute_timestamp,
1116-
),
1117-
(
1118-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
1119-
parse_str_to_second_timestamp,
1120-
),
1121-
(
1122-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
1123-
parse_str_to_microsecond_timestamp,
1124-
),
1125-
(
1126-
Regex::new(r"^T\d{1,2}$").unwrap(),
1127-
parse_str_to_time_only_timestamp,
1128-
),
1123+
let patterns: &[TimestampPattern<T>] = &[
1124+
(&RE_YEAR, parse_str_to_year_timestamp),
1125+
(&RE_MONTH, parse_str_to_month_timestamp),
1126+
(&RE_DAY, parse_str_to_day_timestamp),
1127+
(&RE_HOUR, parse_str_to_hour_timestamp),
1128+
(&RE_MINUTE, parse_str_to_minute_timestamp),
1129+
(&RE_SECOND, parse_str_to_second_timestamp),
1130+
(&RE_MICROSECOND, parse_str_to_microsecond_timestamp),
1131+
(&RE_TIME_ONLY, parse_str_to_time_only_timestamp),
11291132
];
11301133

11311134
let mut timestamp = None;
@@ -1140,7 +1143,7 @@ fn timestamp_parser<T: TimeZone>(
11401143

11411144
if timestamp.is_none() {
11421145
return if eval_mode == EvalMode::Ansi {
1143-
Err(SparkError::CastInvalidValue {
1146+
Err(SparkError::InvalidInputInCastToDatetime {
11441147
value: value.to_string(),
11451148
from_type: "STRING".to_string(),
11461149
to_type: "TIMESTAMP".to_string(),
@@ -1150,12 +1153,7 @@ fn timestamp_parser<T: TimeZone>(
11501153
};
11511154
}
11521155

1153-
match timestamp {
1154-
Some(ts) => Ok(Some(ts)),
1155-
None => Err(SparkError::Internal(
1156-
"Failed to parse timestamp".to_string(),
1157-
)),
1158-
}
1156+
Ok(timestamp)
11591157
}
11601158

11611159
fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
@@ -1202,17 +1200,20 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
12021200
}
12031201

12041202
fn is_valid_digits(segment: i32, digits: usize) -> bool {
1205-
// An integer is able to represent a date within [+-]5 million years.
1203+
// NaiveDate is bounded to [-262143, 262142] (6 digits). We allow up to 7 digits to support
1204+
// leading-zero year strings like "0002020" (= year 2020), matching Spark's
1205+
// isValidDigits. Values outside the bounds are caught by an explicit bounds
1206+
// check below.
12061207
let max_digits_year = 7;
1207-
//year (segment 0) can be between 4 to 7 digits,
1208-
//month and day (segment 1 and 2) can be between 1 to 2 digits
1208+
// year (segment 0) can be between 4 to 7 digits,
1209+
// month and day (segment 1 and 2) can be between 1 to 2 digits
12091210
(segment == 0 && digits >= 4 && digits <= max_digits_year)
12101211
|| (segment != 0 && digits > 0 && digits <= 2)
12111212
}
12121213

12131214
fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
12141215
if eval_mode == EvalMode::Ansi {
1215-
Err(SparkError::CastInvalidValue {
1216+
Err(SparkError::InvalidInputInCastToDatetime {
12161217
value: date_str.to_string(),
12171218
from_type: "STRING".to_string(),
12181219
to_type: "DATE".to_string(),
@@ -1285,11 +1286,13 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
12851286

12861287
date_segments[current_segment as usize] = current_segment_value.0;
12871288

1288-
match NaiveDate::from_ymd_opt(
1289-
sign * date_segments[0],
1290-
date_segments[1] as u32,
1291-
date_segments[2] as u32,
1292-
) {
1289+
// Reject out-of-range years explicitly
1290+
let year = sign * date_segments[0];
1291+
if !(-262143..=262142).contains(&year) {
1292+
return Ok(None);
1293+
}
1294+
1295+
match NaiveDate::from_ymd_opt(year, date_segments[1] as u32, date_segments[2] as u32) {
12931296
Some(date) => {
12941297
let duration_since_epoch = date
12951298
.signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
@@ -1341,7 +1344,8 @@ mod tests {
13411344
TimestampMicrosecondType,
13421345
timestamp_parser,
13431346
tz
1344-
);
1347+
)
1348+
.unwrap();
13451349

13461350
assert_eq!(
13471351
result.data_type(),
@@ -1350,6 +1354,33 @@ mod tests {
13501354
assert_eq!(result.len(), 4);
13511355
}
13521356

1357+
#[test]
1358+
fn test_cast_string_to_timestamp_ansi_error() {
1359+
// In ANSI mode, an invalid timestamp string must produce an error rather than null.
1360+
let array: ArrayRef = Arc::new(StringArray::from(vec![
1361+
Some("2020-01-01T12:34:56.123456"),
1362+
Some("not_a_timestamp"),
1363+
]));
1364+
let tz = &timezone::Tz::from_str("UTC").unwrap();
1365+
let string_array = array
1366+
.as_any()
1367+
.downcast_ref::<GenericStringArray<i32>>()
1368+
.expect("Expected a string array");
1369+
1370+
let eval_mode = EvalMode::Ansi;
1371+
let result = cast_utf8_to_timestamp!(
1372+
&string_array,
1373+
eval_mode,
1374+
TimestampMicrosecondType,
1375+
timestamp_parser,
1376+
tz
1377+
);
1378+
assert!(
1379+
result.is_err(),
1380+
"ANSI mode should return Err for an invalid timestamp string"
1381+
);
1382+
}
1383+
13531384
#[test]
13541385
fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
13551386
// prepare input data

spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ object SparkErrorConverter extends ShimSparkErrorConverter {
100100
case None => Array.empty[QueryContext] // No context
101101
}
102102

103-
val summary: String = errorJson.summary.orNull
103+
val summary: String = errorJson.summary.getOrElse("")
104104

105105
// Delegate to version-specific shim - let conversion exceptions propagate
106106
val optEx = convertErrorType(errorJson.errorType, errorClass, params, sparkContext, summary)

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
217217
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
218218
case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
219219
Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
220-
case DataTypes.TimestampType if evalMode == CometEvalMode.ANSI =>
221-
Incompatible(Some("ANSI mode not supported"))
222220
case DataTypes.TimestampType =>
223-
// https://github.com/apache/datafusion-comet/issues/328
224221
Incompatible(Some("Not all valid formats are supported"))
225222
case _ =>
226223
unsupported(DataTypes.StringType, toType)

spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import java.io.FileNotFoundException
2323

2424
import scala.util.matching.Regex
2525

26-
import org.apache.spark.{QueryContext, SparkException}
26+
import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
2727
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
2828
import org.apache.spark.sql.errors.QueryExecutionErrors
2929
import org.apache.spark.sql.types._
@@ -172,6 +172,22 @@ trait ShimSparkErrorConverter {
172172
QueryExecutionErrors
173173
.invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
174174

175+
case "InvalidInputInCastToDatetime" =>
176+
val expression =
177+
s"'${params("value").toString.replace("\\", "\\\\").replace("'", "\\'")}'"
178+
val sourceType = s""""${params("fromType").toString}""""
179+
val targetType = s""""${params("toType").toString}""""
180+
Some(
181+
new SparkDateTimeException(
182+
errorClass = "CAST_INVALID_INPUT",
183+
messageParameters = Map(
184+
"expression" -> expression,
185+
"sourceType" -> sourceType,
186+
"targetType" -> targetType,
187+
"ansiConfig" -> "\"spark.sql.ansi.enabled\""),
188+
context = context,
189+
summary = summary))
190+
175191
case "CastOverFlow" =>
176192
val fromType = getDataType(params("fromType").toString)
177193
val toType = getDataType(params("toType").toString)

0 commit comments

Comments
 (0)