diff --git a/sqlx-sqlite/src/regexp.rs b/sqlx-sqlite/src/regexp.rs index eb14fffc77..b525992d9d 100644 --- a/sqlx-sqlite/src/regexp.rs +++ b/sqlx-sqlite/src/regexp.rs @@ -136,8 +136,11 @@ unsafe fn get_regex_from_arg( Some(regex) } -/// Get a text reference of the value of `arg`. If this value is not a string value, an error is printed and `None` is -/// returned. +/// Get a text reference of the value of `arg`. Returns `None` for NULL values. +/// +/// For non-NULL values, `sqlite3_value_text()` is called directly, which lets SQLite +/// coerce INTEGER, REAL, and BLOB values to their text representation. This matches +/// the coercion behavior documented at . /// /// The returned `&str` is valid for lifetime `'a` which can be determined by the caller. This lifetime should **not** /// outlive `ctx`. @@ -146,20 +149,19 @@ unsafe fn get_text_from_arg<'a>( arg: *mut ffi::sqlite3_value, ) -> Option<&'a str> { let ty = ffi::sqlite3_value_type(arg); - if ty == ffi::SQLITE_TEXT { - let ptr = ffi::sqlite3_value_text(arg); - let len = ffi::sqlite3_value_bytes(arg); - let slice = std::slice::from_raw_parts(ptr.cast(), len as usize); - match std::str::from_utf8(slice) { - Ok(result) => Some(result), - Err(e) => { - log::error!("Incoming text is not valid UTF8: {e:?}"); - ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); - None - } + if ty == ffi::SQLITE_NULL { + return None; + } + let ptr = ffi::sqlite3_value_text(arg); + let len = ffi::sqlite3_value_bytes(arg); + let slice = std::slice::from_raw_parts(ptr.cast(), len as usize); + match std::str::from_utf8(slice) { + Ok(result) => Some(result), + Err(e) => { + log::error!("Incoming text is not valid UTF8: {e:?}"); + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + None } - } else { - None } } @@ -222,6 +224,52 @@ mod tests { assert!(result.is_empty()); } + #[sqlx::test] + async fn test_regexp_coerces_non_text_values() { + let mut conn = crate::SqliteConnectOptions::from_str("sqlite://:memory:") + .unwrap() + .with_regexp() + .connect() + .await + .unwrap(); + + // INTEGER coercion + let result: Option = sqlx::query_scalar("SELECT 123 REGEXP '23'") + .fetch_one(&mut conn) + .await + .unwrap(); + assert_eq!(result, Some(1)); + + // REAL coercion + let result: Option = sqlx::query_scalar("SELECT 12.5 REGEXP '12\\.5'") + .fetch_one(&mut conn) + .await + .unwrap(); + assert_eq!(result, Some(1)); + + // INTEGER column + sqlx::query("CREATE TABLE int_test (x INTEGER NOT NULL)") + .execute(&mut conn) + .await + .unwrap(); + sqlx::query("INSERT INTO int_test VALUES (123), (45)") + .execute(&mut conn) + .await + .unwrap(); + let rows: Vec = sqlx::query_scalar("SELECT x FROM int_test WHERE x REGEXP '23'") + .fetch_all(&mut conn) + .await + .unwrap(); + assert_eq!(rows, vec![123]); + + // NULL should return NULL, not match + let result: Option = sqlx::query_scalar("SELECT NULL REGEXP '.*'") + .fetch_one(&mut conn) + .await + .unwrap(); + assert_eq!(result, None); + } + #[sqlx::test] async fn test_invalid_regexp_should_fail() { let mut conn = test_db().await;