From f79da7194f12a1093b7e5bdb922d87cdcd4a3311 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Mon, 14 Apr 2025 11:02:15 +0300 Subject: [PATCH] implement Iterator for SmallVec and add const generic for array size --- core/storage/sqlite3_ondisk.rs | 84 ++++++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 10251ca51..5f742887e 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -1057,13 +1057,16 @@ pub fn validate_serial_type(value: u64) -> Result { } } -struct SmallVec { - pub data: [std::mem::MaybeUninit; 64], +pub struct SmallVec { + /// Stack allocated data + pub data: [std::mem::MaybeUninit; N], + /// Length of the vector, accounting for both stack and heap allocated data pub len: usize, + /// Extra data on heap pub extra_data: Option>, } -impl SmallVec { +impl SmallVec { pub fn new() -> Self { Self { data: unsafe { std::mem::MaybeUninit::uninit().assume_init() }, @@ -1084,6 +1087,50 @@ impl SmallVec { self.len += 1; } } + + fn get_from_heap(&self, index: usize) -> T { + assert!(self.extra_data.is_some()); + assert!(index >= self.data.len()); + let extra_data_index = index - self.data.len(); + let extra_data = self.extra_data.as_ref().unwrap(); + assert!(extra_data_index < extra_data.len()); + extra_data[extra_data_index] + } + + pub fn get(&self, index: usize) -> Option { + if index >= self.len { + return None; + } + let data_is_on_stack = index < self.data.len(); + if data_is_on_stack { + // SAFETY: We know this index is initialized we checked for index < self.len earlier above. + unsafe { Some(self.data[index].assume_init()) } + } else { + Some(self.get_from_heap(index)) + } + } +} + +impl SmallVec { + pub fn iter(&self) -> SmallVecIter<'_, T, N> { + SmallVecIter { vec: self, pos: 0 } + } +} + +pub struct SmallVecIter<'a, T, const N: usize> { + vec: &'a SmallVec, + pos: usize, +} + +impl<'a, T: Default + Copy, const N: usize> Iterator for SmallVecIter<'a, T, N> { + type Item = T; + + fn next(&mut self) -> Option { + self.vec.get(self.pos).map(|item| { + self.pos += 1; + item + }) + } } pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Result<()> { @@ -1099,7 +1146,7 @@ pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Res let mut header_size = (header_size as usize) - nr; pos += nr; - let mut serial_types = SmallVec::new(); + let mut serial_types = SmallVec::::new(); while header_size > 0 { let (serial_type, nr) = read_varint(&reuse_immutable.get_payload()[pos..])?; let serial_type = validate_serial_type(serial_type)?; @@ -1685,4 +1732,33 @@ mod tests { let result = validate_serial_type(10); assert!(result.is_err()); } + + #[test] + fn test_smallvec_iter() { + let mut small_vec = SmallVec::::new(); + (0..8).for_each(|i| small_vec.push(i)); + + let mut iter = small_vec.iter(); + assert_eq!(iter.next(), Some(0)); + assert_eq!(iter.next(), Some(1)); + assert_eq!(iter.next(), Some(2)); + assert_eq!(iter.next(), Some(3)); + assert_eq!(iter.next(), Some(4)); + assert_eq!(iter.next(), Some(5)); + assert_eq!(iter.next(), Some(6)); + assert_eq!(iter.next(), Some(7)); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_smallvec_get() { + let mut small_vec = SmallVec::::new(); + (0..8).for_each(|i| small_vec.push(i)); + + (0..8).for_each(|i| { + assert_eq!(small_vec.get(i), Some(i as i32)); + }); + + assert_eq!(small_vec.get(8), None); + } }