Skip to content

Commit a05c568

Browse files
authored
perf: use aligned pointer reads for SparkUnsafeRow field accessors (#3670)
* perf: use aligned pointer reads for SparkUnsafeRow field accessors SparkUnsafeRow field offsets are always 8-byte aligned (the JVM guarantees 8-byte alignment on the base address, bitset_width is a multiple of 8, and each field slot is 8 bytes). This means we can safely use ptr::read() instead of the from_le_bytes(slice) pattern for all typed accesses, avoiding slice creation and try_into overhead. Move primitive accessor implementations out of the SparkUnsafeObject trait defaults and into each concrete impl via a macro parameterized on the read method: - SparkUnsafeRow uses ptr::read() (aligned) - SparkUnsafeArray uses ptr::read_unaligned() (may be unaligned when nested in a row's variable-length region) Also switch is_null_at/set_not_null_at in SparkUnsafeRow from read_unaligned/write_unaligned to aligned read/write, since the null bitset is always at 8-byte aligned offsets within the row. * fix: use 8-byte aligned buffer in SparkUnsafeRow Miri test The test_append_null_struct_field_to_struct_builder test used a plain [u8; 16] stack buffer with no alignment guarantee. Since is_null_at performs aligned i64 reads, Miri flags this as undefined behavior when the buffer lands at a non-8-byte-aligned address. Wrap the buffer in a #[repr(align(8))] struct to match the alignment that real Spark UnsafeRow data always has from JVM memory.
1 parent 42c806b commit a05c568

2 files changed

Lines changed: 132 additions & 103 deletions

File tree

native/core/src/execution/shuffle/spark_unsafe/list.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ use crate::{
1919
errors::CometError,
2020
execution::shuffle::spark_unsafe::{
2121
map::append_map_elements,
22-
row::{append_field, downcast_builder_ref, SparkUnsafeObject, SparkUnsafeRow},
22+
row::{
23+
append_field, downcast_builder_ref, impl_primitive_accessors, SparkUnsafeObject,
24+
SparkUnsafeRow,
25+
},
2326
},
2427
};
2528
use arrow::array::{
@@ -101,6 +104,10 @@ impl SparkUnsafeObject for SparkUnsafeArray {
101104
fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 {
102105
(self.element_offset + (index * element_size) as i64) as *const u8
103106
}
107+
108+
// SparkUnsafeArray base address may be unaligned when nested within a row's variable-length
109+
// region, so we must use ptr::read_unaligned() for all typed accesses.
110+
impl_primitive_accessors!(read_unaligned);
104111
}
105112

106113
impl SparkUnsafeArray {

native/core/src/execution/shuffle/spark_unsafe/row.rs

Lines changed: 124 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,32 @@ const NESTED_TYPE_BUILDER_CAPACITY: usize = 100;
7171
/// safe to call as long as:
7272
/// - The index is within bounds (caller's responsibility)
7373
/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
74+
///
75+
/// # Alignment
76+
///
77+
/// Primitive accessor methods are implemented separately for each type because they have
78+
/// different alignment guarantees:
79+
/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is a multiple of 8,
80+
/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`.
81+
/// - `SparkUnsafeArray`: The array base address may be unaligned when nested within a row's
82+
/// variable-length region, so accessors use `ptr::read_unaligned()`.
7483
pub trait SparkUnsafeObject {
7584
/// Returns the address of the row.
7685
fn get_row_addr(&self) -> i64;
7786

7887
/// Returns the offset of the element at the given index.
7988
fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8;
8089

90+
fn get_boolean(&self, index: usize) -> bool;
91+
fn get_byte(&self, index: usize) -> i8;
92+
fn get_short(&self, index: usize) -> i16;
93+
fn get_int(&self, index: usize) -> i32;
94+
fn get_long(&self, index: usize) -> i64;
95+
fn get_float(&self, index: usize) -> f32;
96+
fn get_double(&self, index: usize) -> f64;
97+
fn get_date(&self, index: usize) -> i32;
98+
fn get_timestamp(&self, index: usize) -> i64;
99+
81100
/// Returns the offset and length of the element at the given index.
82101
#[inline]
83102
fn get_offset_and_len(&self, index: usize) -> (i32, i32) {
@@ -87,79 +106,6 @@ pub trait SparkUnsafeObject {
87106
(offset, len)
88107
}
89108

90-
/// Returns boolean value at the given index of the object.
91-
#[inline]
92-
fn get_boolean(&self, index: usize) -> bool {
93-
let addr = self.get_element_offset(index, 1);
94-
// SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region.
95-
// The caller ensures index is within bounds.
96-
debug_assert!(
97-
!addr.is_null(),
98-
"get_boolean: null pointer at index {index}"
99-
);
100-
unsafe { *addr != 0 }
101-
}
102-
103-
/// Returns byte value at the given index of the object.
104-
#[inline]
105-
fn get_byte(&self, index: usize) -> i8 {
106-
let addr = self.get_element_offset(index, 1);
107-
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
108-
debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
109-
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) };
110-
i8::from_le_bytes(slice.try_into().unwrap())
111-
}
112-
113-
/// Returns short value at the given index of the object.
114-
#[inline]
115-
fn get_short(&self, index: usize) -> i16 {
116-
let addr = self.get_element_offset(index, 2);
117-
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
118-
debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
119-
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) };
120-
i16::from_le_bytes(slice.try_into().unwrap())
121-
}
122-
123-
/// Returns integer value at the given index of the object.
124-
#[inline]
125-
fn get_int(&self, index: usize) -> i32 {
126-
let addr = self.get_element_offset(index, 4);
127-
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
128-
debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
129-
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
130-
i32::from_le_bytes(slice.try_into().unwrap())
131-
}
132-
133-
/// Returns long value at the given index of the object.
134-
#[inline]
135-
fn get_long(&self, index: usize) -> i64 {
136-
let addr = self.get_element_offset(index, 8);
137-
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
138-
debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
139-
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
140-
i64::from_le_bytes(slice.try_into().unwrap())
141-
}
142-
143-
/// Returns float value at the given index of the object.
144-
#[inline]
145-
fn get_float(&self, index: usize) -> f32 {
146-
let addr = self.get_element_offset(index, 4);
147-
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
148-
debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
149-
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
150-
f32::from_le_bytes(slice.try_into().unwrap())
151-
}
152-
153-
/// Returns double value at the given index of the object.
154-
#[inline]
155-
fn get_double(&self, index: usize) -> f64 {
156-
let addr = self.get_element_offset(index, 8);
157-
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
158-
debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
159-
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
160-
f64::from_le_bytes(slice.try_into().unwrap())
161-
}
162-
163109
/// Returns string value at the given index of the object.
164110
fn get_string(&self, index: usize) -> &str {
165111
let (offset, len) = self.get_offset_and_len(index);
@@ -190,29 +136,6 @@ pub trait SparkUnsafeObject {
190136
unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
191137
}
192138

193-
/// Returns date value at the given index of the object.
194-
#[inline]
195-
fn get_date(&self, index: usize) -> i32 {
196-
let addr = self.get_element_offset(index, 4);
197-
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
198-
debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
199-
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
200-
i32::from_le_bytes(slice.try_into().unwrap())
201-
}
202-
203-
/// Returns timestamp value at the given index of the object.
204-
#[inline]
205-
fn get_timestamp(&self, index: usize) -> i64 {
206-
let addr = self.get_element_offset(index, 8);
207-
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
208-
debug_assert!(
209-
!addr.is_null(),
210-
"get_timestamp: null pointer at index {index}"
211-
);
212-
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
213-
i64::from_le_bytes(slice.try_into().unwrap())
214-
}
215-
216139
/// Returns decimal value at the given index of the object.
217140
fn get_decimal(&self, index: usize, precision: u8) -> i128 {
218141
if precision <= MAX_LONG_DIGITS {
@@ -244,6 +167,94 @@ pub trait SparkUnsafeObject {
244167
}
245168
}
246169

170+
/// Generates primitive accessor implementations for `SparkUnsafeObject`.
171+
///
172+
/// Uses `$read_method` to read typed values from raw pointers:
173+
/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte aligned)
174+
/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray)
175+
macro_rules! impl_primitive_accessors {
176+
($read_method:ident) => {
177+
#[inline]
178+
fn get_boolean(&self, index: usize) -> bool {
179+
let addr = self.get_element_offset(index, 1);
180+
debug_assert!(
181+
!addr.is_null(),
182+
"get_boolean: null pointer at index {index}"
183+
);
184+
// SAFETY: addr points to valid element data within the row/array region.
185+
unsafe { *addr != 0 }
186+
}
187+
188+
#[inline]
189+
fn get_byte(&self, index: usize) -> i8 {
190+
let addr = self.get_element_offset(index, 1);
191+
debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
192+
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
193+
unsafe { *(addr as *const i8) }
194+
}
195+
196+
#[inline]
197+
fn get_short(&self, index: usize) -> i16 {
198+
let addr = self.get_element_offset(index, 2) as *const i16;
199+
debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
200+
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
201+
unsafe { addr.$read_method() }
202+
}
203+
204+
#[inline]
205+
fn get_int(&self, index: usize) -> i32 {
206+
let addr = self.get_element_offset(index, 4) as *const i32;
207+
debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
208+
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
209+
unsafe { addr.$read_method() }
210+
}
211+
212+
#[inline]
213+
fn get_long(&self, index: usize) -> i64 {
214+
let addr = self.get_element_offset(index, 8) as *const i64;
215+
debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
216+
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
217+
unsafe { addr.$read_method() }
218+
}
219+
220+
#[inline]
221+
fn get_float(&self, index: usize) -> f32 {
222+
let addr = self.get_element_offset(index, 4) as *const f32;
223+
debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
224+
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
225+
unsafe { addr.$read_method() }
226+
}
227+
228+
#[inline]
229+
fn get_double(&self, index: usize) -> f64 {
230+
let addr = self.get_element_offset(index, 8) as *const f64;
231+
debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
232+
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
233+
unsafe { addr.$read_method() }
234+
}
235+
236+
#[inline]
237+
fn get_date(&self, index: usize) -> i32 {
238+
let addr = self.get_element_offset(index, 4) as *const i32;
239+
debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
240+
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
241+
unsafe { addr.$read_method() }
242+
}
243+
244+
#[inline]
245+
fn get_timestamp(&self, index: usize) -> i64 {
246+
let addr = self.get_element_offset(index, 8) as *const i64;
247+
debug_assert!(
248+
!addr.is_null(),
249+
"get_timestamp: null pointer at index {index}"
250+
);
251+
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
252+
unsafe { addr.$read_method() }
253+
}
254+
};
255+
}
256+
pub(crate) use impl_primitive_accessors;
257+
247258
pub struct SparkUnsafeRow {
248259
row_addr: i64,
249260
row_size: i32,
@@ -265,6 +276,11 @@ impl SparkUnsafeObject for SparkUnsafeRow {
265276
);
266277
(self.row_addr + offset) as *const u8
267278
}
279+
280+
// SparkUnsafeRow field offsets are always 8-byte aligned: the base address is 8-byte
281+
// aligned (JVM guarantee), bitset_width is a multiple of 8, and each field slot is
282+
// 8 bytes. This means we can safely use aligned ptr::read() for all typed accesses.
283+
impl_primitive_accessors!(read);
268284
}
269285

270286
impl Default for SparkUnsafeRow {
@@ -328,11 +344,13 @@ impl SparkUnsafeRow {
328344
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
329345
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
330346
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
347+
// The bitset starts at row_addr (8-byte aligned) and each word is at offset 8*k,
348+
// so word_offset is always 8-byte aligned — we can use aligned ptr::read().
331349
debug_assert!(self.row_addr != -1, "is_null_at: row not initialized");
332350
unsafe {
333351
let mask: i64 = 1i64 << (index & 0x3f);
334352
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
335-
let word: i64 = word_offset.read_unaligned();
353+
let word: i64 = word_offset.read();
336354
(word & mask) != 0
337355
}
338356
}
@@ -343,12 +361,13 @@ impl SparkUnsafeRow {
343361
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
344362
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
345363
// Writing is safe because we have mutable access and the memory is owned by the JVM.
364+
// The bitset is always 8-byte aligned — we can use aligned ptr::read()/write().
346365
debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized");
347366
unsafe {
348367
let mask: i64 = 1i64 << (index & 0x3f);
349368
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
350-
let word: i64 = word_offset.read_unaligned();
351-
word_offset.write_unaligned(word & !mask);
369+
let word: i64 = word_offset.read();
370+
word_offset.write(word & !mask);
352371
}
353372
}
354373
}
@@ -1668,9 +1687,12 @@ mod test {
16681687
let mut row = SparkUnsafeRow::new_with_num_fields(1);
16691688
// 8 bytes null bitset + 8 bytes field value = 16 bytes
16701689
// Set bit 0 in the null bitset to mark field 0 as null
1671-
let mut data = [0u8; 16];
1672-
data[0] = 1;
1673-
row.point_to_slice(&data);
1690+
// Use aligned buffer to match real Spark UnsafeRow layout (8-byte aligned)
1691+
#[repr(align(8))]
1692+
struct Aligned([u8; 16]);
1693+
let mut data = Aligned([0u8; 16]);
1694+
data.0[0] = 1;
1695+
row.point_to_slice(&data.0);
16741696
append_field(&data_type, &mut struct_builder, &row, 0).expect("append field");
16751697
struct_builder.append_null();
16761698
let struct_array = struct_builder.finish();

0 commit comments

Comments
 (0)