implement Iterator for SmallVec and add const generic for array size

This commit is contained in:
Jussi Saurio
2025-04-14 11:02:15 +03:00
parent f5523e7a23
commit f79da7194f

View File

@@ -1057,13 +1057,16 @@ pub fn validate_serial_type(value: u64) -> Result<SerialType> {
}
}
struct SmallVec<T> {
pub data: [std::mem::MaybeUninit<T>; 64],
pub struct SmallVec<T, const N: usize = 64> {
/// Stack allocated data
pub data: [std::mem::MaybeUninit<T>; N],
/// Length of the vector, accounting for both stack and heap allocated data
pub len: usize,
/// Extra data on heap
pub extra_data: Option<Vec<T>>,
}
impl<T: Default + Copy> SmallVec<T> {
impl<T: Default + Copy, const N: usize> SmallVec<T, N> {
pub fn new() -> Self {
Self {
data: unsafe { std::mem::MaybeUninit::uninit().assume_init() },
@@ -1084,6 +1087,50 @@ impl<T: Default + Copy> SmallVec<T> {
self.len += 1;
}
}
fn get_from_heap(&self, index: usize) -> T {
assert!(self.extra_data.is_some());
assert!(index >= self.data.len());
let extra_data_index = index - self.data.len();
let extra_data = self.extra_data.as_ref().unwrap();
assert!(extra_data_index < extra_data.len());
extra_data[extra_data_index]
}
pub fn get(&self, index: usize) -> Option<T> {
if index >= self.len {
return None;
}
let data_is_on_stack = index < self.data.len();
if data_is_on_stack {
// SAFETY: We know this index is initialized we checked for index < self.len earlier above.
unsafe { Some(self.data[index].assume_init()) }
} else {
Some(self.get_from_heap(index))
}
}
}
impl<T: Default + Copy, const N: usize> SmallVec<T, N> {
pub fn iter(&self) -> SmallVecIter<'_, T, N> {
SmallVecIter { vec: self, pos: 0 }
}
}
pub struct SmallVecIter<'a, T, const N: usize> {
vec: &'a SmallVec<T, N>,
pos: usize,
}
impl<'a, T: Default + Copy, const N: usize> Iterator for SmallVecIter<'a, T, N> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
self.vec.get(self.pos).map(|item| {
self.pos += 1;
item
})
}
}
pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Result<()> {
@@ -1099,7 +1146,7 @@ pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Res
let mut header_size = (header_size as usize) - nr;
pos += nr;
let mut serial_types = SmallVec::new();
let mut serial_types = SmallVec::<u64, 64>::new();
while header_size > 0 {
let (serial_type, nr) = read_varint(&reuse_immutable.get_payload()[pos..])?;
let serial_type = validate_serial_type(serial_type)?;
@@ -1685,4 +1732,33 @@ mod tests {
let result = validate_serial_type(10);
assert!(result.is_err());
}
#[test]
fn test_smallvec_iter() {
let mut small_vec = SmallVec::<i32, 4>::new();
(0..8).for_each(|i| small_vec.push(i));
let mut iter = small_vec.iter();
assert_eq!(iter.next(), Some(0));
assert_eq!(iter.next(), Some(1));
assert_eq!(iter.next(), Some(2));
assert_eq!(iter.next(), Some(3));
assert_eq!(iter.next(), Some(4));
assert_eq!(iter.next(), Some(5));
assert_eq!(iter.next(), Some(6));
assert_eq!(iter.next(), Some(7));
assert_eq!(iter.next(), None);
}
#[test]
fn test_smallvec_get() {
let mut small_vec = SmallVec::<i32, 4>::new();
(0..8).for_each(|i| small_vec.push(i));
(0..8).for_each(|i| {
assert_eq!(small_vec.get(i), Some(i as i32));
});
assert_eq!(small_vec.get(8), None);
}
}