From af6a783f4d557d99dbb35e66b3ed5b046e9bf7c4 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Thu, 24 Apr 2025 15:36:54 +0300 Subject: [PATCH 1/2] core/types: remove duplicate serialtype implementation --- core/storage/sqlite3_ondisk.rs | 470 +++++++++++++-------------------- core/types.rs | 45 +++- 2 files changed, 221 insertions(+), 294 deletions(-) diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index fccf233b5..2e86bbdff 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -47,7 +47,7 @@ use crate::io::{Buffer, Completion, ReadCompletion, SyncCompletion, WriteComplet use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; use crate::storage::pager::Pager; -use crate::types::{ImmutableRecord, RawSlice, RefValue, TextRef, TextSubtype}; +use crate::types::{ImmutableRecord, RawSlice, RefValue, SerialType, TextRef, TextSubtype}; use crate::{File, Result}; use std::cell::RefCell; use std::mem::MaybeUninit; @@ -985,107 +985,8 @@ fn read_payload(unread: &'static [u8], payload_size: usize) -> (&'static [u8], O } } -pub type SerialType = u64; - -pub const SERIAL_TYPE_NULL: SerialType = 0; -pub const SERIAL_TYPE_INT8: SerialType = 1; -pub const SERIAL_TYPE_BEINT16: SerialType = 2; -pub const SERIAL_TYPE_BEINT24: SerialType = 3; -pub const SERIAL_TYPE_BEINT32: SerialType = 4; -pub const SERIAL_TYPE_BEINT48: SerialType = 5; -pub const SERIAL_TYPE_BEINT64: SerialType = 6; -pub const SERIAL_TYPE_BEFLOAT64: SerialType = 7; -pub const SERIAL_TYPE_CONSTINT0: SerialType = 8; -pub const SERIAL_TYPE_CONSTINT1: SerialType = 9; - -pub trait SerialTypeExt { - fn is_null(self) -> bool; - fn is_int8(self) -> bool; - fn is_beint16(self) -> bool; - fn is_beint24(self) -> bool; - fn is_beint32(self) -> bool; - fn is_beint48(self) -> bool; - fn is_beint64(self) -> bool; - fn is_befloat64(self) -> bool; - fn is_constint0(self) -> bool; - fn is_constint1(self) -> bool; - fn is_blob(self) -> bool; - fn is_string(self) -> bool; - fn blob_size(self) -> usize; - fn string_size(self) -> usize; - fn is_valid(self) -> bool; -} - -impl SerialTypeExt for u64 { - fn is_null(self) -> bool { - self == SERIAL_TYPE_NULL - } - - fn is_int8(self) -> bool { - self == SERIAL_TYPE_INT8 - } - - fn is_beint16(self) -> bool { - self == SERIAL_TYPE_BEINT16 - } - - fn is_beint24(self) -> bool { - self == SERIAL_TYPE_BEINT24 - } - - fn is_beint32(self) -> bool { - self == SERIAL_TYPE_BEINT32 - } - - fn is_beint48(self) -> bool { - self == SERIAL_TYPE_BEINT48 - } - - fn is_beint64(self) -> bool { - self == SERIAL_TYPE_BEINT64 - } - - fn is_befloat64(self) -> bool { - self == SERIAL_TYPE_BEFLOAT64 - } - - fn is_constint0(self) -> bool { - self == SERIAL_TYPE_CONSTINT0 - } - - fn is_constint1(self) -> bool { - self == SERIAL_TYPE_CONSTINT1 - } - - fn is_blob(self) -> bool { - self >= 12 && self % 2 == 0 - } - - fn is_string(self) -> bool { - self >= 13 && self % 2 == 1 - } - - fn blob_size(self) -> usize { - debug_assert!(self.is_blob()); - ((self - 12) / 2) as usize - } - - fn string_size(self) -> usize { - debug_assert!(self.is_string()); - ((self - 13) / 2) as usize - } - - fn is_valid(self) -> bool { - self <= 9 || self.is_blob() || self.is_string() - } -} - pub fn validate_serial_type(value: u64) -> Result { - if value.is_valid() { - Ok(value) - } else { - crate::bail_corrupt_error!("Invalid serial type: {}", value) - } + value.try_into() } pub struct SmallVec { @@ -1180,7 +1081,7 @@ pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Res let mut serial_types = SmallVec::::new(); while header_size > 0 { let (serial_type, nr) = read_varint(&reuse_immutable.get_payload()[pos..])?; - let serial_type = validate_serial_type(serial_type)?; + let _ = validate_serial_type(serial_type)?; serial_types.push(serial_type); pos += nr; assert!(header_size >= nr); @@ -1189,14 +1090,17 @@ pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Res for &serial_type in &serial_types.data[..serial_types.len.min(serial_types.data.len())] { let (value, n) = read_value(&reuse_immutable.get_payload()[pos..], unsafe { - *serial_type.as_ptr() + serial_type.assume_init().try_into()? })?; pos += n; reuse_immutable.add_value(value); } if let Some(extra) = serial_types.extra_data.as_ref() { for serial_type in extra { - let (value, n) = read_value(&reuse_immutable.get_payload()[pos..], *serial_type)?; + let (value, n) = read_value( + &reuse_immutable.get_payload()[pos..], + (*serial_type).try_into()?, + )?; pos += n; reuse_immutable.add_value(value); } @@ -1209,140 +1113,123 @@ pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Res /// always. #[inline(always)] pub fn read_value(buf: &[u8], serial_type: SerialType) -> Result<(RefValue, usize)> { - if serial_type.is_null() { - return Ok((RefValue::Null, 0)); - } - - if serial_type.is_int8() { - if buf.is_empty() { - crate::bail_corrupt_error!("Invalid UInt8 value"); + match serial_type { + SerialType::Null => Ok((RefValue::Null, 0)), + SerialType::I8 => { + if buf.is_empty() { + crate::bail_corrupt_error!("Invalid UInt8 value"); + } + let val = buf[0] as i8; + Ok((RefValue::Integer(val as i64), 1)) } - let val = buf[0] as i8; - return Ok((RefValue::Integer(val as i64), 1)); - } - - if serial_type.is_beint16() { - if buf.len() < 2 { - crate::bail_corrupt_error!("Invalid BEInt16 value"); + SerialType::I16 => { + if buf.len() < 2 { + crate::bail_corrupt_error!("Invalid BEInt16 value"); + } + Ok(( + RefValue::Integer(i16::from_be_bytes([buf[0], buf[1]]) as i64), + 2, + )) } - return Ok(( - RefValue::Integer(i16::from_be_bytes([buf[0], buf[1]]) as i64), - 2, - )); - } - - if serial_type.is_beint24() { - if buf.len() < 3 { - crate::bail_corrupt_error!("Invalid BEInt24 value"); + SerialType::I24 => { + if buf.len() < 3 { + crate::bail_corrupt_error!("Invalid BEInt24 value"); + } + let sign_extension = if buf[0] <= 127 { 0 } else { 255 }; + Ok(( + RefValue::Integer( + i32::from_be_bytes([sign_extension, buf[0], buf[1], buf[2]]) as i64 + ), + 3, + )) } - let sign_extension = if buf[0] <= 127 { 0 } else { 255 }; - return Ok(( - RefValue::Integer(i32::from_be_bytes([sign_extension, buf[0], buf[1], buf[2]]) as i64), - 3, - )); - } - - if serial_type.is_beint32() { - if buf.len() < 4 { - crate::bail_corrupt_error!("Invalid BEInt32 value"); + SerialType::I32 => { + if buf.len() < 4 { + crate::bail_corrupt_error!("Invalid BEInt32 value"); + } + Ok(( + RefValue::Integer(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64), + 4, + )) } - return Ok(( - RefValue::Integer(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64), - 4, - )); - } - - if serial_type.is_beint48() { - if buf.len() < 6 { - crate::bail_corrupt_error!("Invalid BEInt48 value"); + SerialType::I48 => { + if buf.len() < 6 { + crate::bail_corrupt_error!("Invalid BEInt48 value"); + } + let sign_extension = if buf[0] <= 127 { 0 } else { 255 }; + Ok(( + RefValue::Integer(i64::from_be_bytes([ + sign_extension, + sign_extension, + buf[0], + buf[1], + buf[2], + buf[3], + buf[4], + buf[5], + ])), + 6, + )) } - let sign_extension = if buf[0] <= 127 { 0 } else { 255 }; - return Ok(( - RefValue::Integer(i64::from_be_bytes([ - sign_extension, - sign_extension, - buf[0], - buf[1], - buf[2], - buf[3], - buf[4], - buf[5], - ])), - 6, - )); - } - - if serial_type.is_beint64() { - if buf.len() < 8 { - crate::bail_corrupt_error!("Invalid BEInt64 value"); + SerialType::I64 => { + if buf.len() < 8 { + crate::bail_corrupt_error!("Invalid BEInt64 value"); + } + Ok(( + RefValue::Integer(i64::from_be_bytes([ + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], + ])), + 8, + )) } - return Ok(( - RefValue::Integer(i64::from_be_bytes([ - buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], - ])), - 8, - )); - } - - if serial_type.is_befloat64() { - if buf.len() < 8 { - crate::bail_corrupt_error!("Invalid BEFloat64 value"); + SerialType::F64 => { + if buf.len() < 8 { + crate::bail_corrupt_error!("Invalid BEFloat64 value"); + } + Ok(( + RefValue::Float(f64::from_be_bytes([ + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], + ])), + 8, + )) } - return Ok(( - RefValue::Float(f64::from_be_bytes([ - buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], - ])), - 8, - )); - } - - if serial_type.is_constint0() { - return Ok((RefValue::Integer(0), 0)); - } - - if serial_type.is_constint1() { - return Ok((RefValue::Integer(1), 0)); - } - - if serial_type.is_blob() { - let n = serial_type.blob_size(); - if buf.len() < n { - crate::bail_corrupt_error!("Invalid Blob value"); + SerialType::ConstInt0 => Ok((RefValue::Integer(0), 0)), + SerialType::ConstInt1 => Ok((RefValue::Integer(1), 0)), + SerialType::Blob { content_size } => { + if buf.len() < content_size { + crate::bail_corrupt_error!("Invalid Blob value"); + } + if content_size == 0 { + Ok((RefValue::Blob(RawSlice::new(std::ptr::null(), 0)), 0)) + } else { + let ptr = &buf[0] as *const u8; + let slice = RawSlice::new(ptr, content_size); + Ok((RefValue::Blob(slice), content_size)) + } } - if n == 0 { - return Ok((RefValue::Blob(RawSlice::new(std::ptr::null(), 0)), 0)); + SerialType::Text { content_size } => { + if buf.len() < content_size { + crate::bail_corrupt_error!( + "Invalid String value, length {} < expected length {}", + buf.len(), + content_size + ); + } + let slice = if content_size == 0 { + RawSlice::new(std::ptr::null(), 0) + } else { + let ptr = &buf[0] as *const u8; + RawSlice::new(ptr, content_size) + }; + Ok(( + RefValue::Text(TextRef { + value: slice, + subtype: TextSubtype::Text, + }), + content_size, + )) } - let ptr = &buf[0] as *const u8; - let slice = RawSlice::new(ptr, n); - return Ok((RefValue::Blob(slice), n)); } - - if serial_type.is_string() { - let n = serial_type.string_size(); - if buf.len() < n { - crate::bail_corrupt_error!( - "Invalid String value, length {} < expected length {}", - buf.len(), - n - ); - } - let slice = if n == 0 { - RawSlice::new(std::ptr::null(), 0) - } else { - let ptr = &buf[0] as *const u8; - RawSlice::new(ptr, n) - }; - return Ok(( - RefValue::Text(TextRef { - value: slice, - subtype: TextSubtype::Text, - }), - n, - )); - } - - // This should never happen if validate_serial_type is used correctly - crate::bail_corrupt_error!("Invalid serial type: {}", serial_type) } #[inline(always)] @@ -1676,32 +1563,32 @@ mod tests { use rstest::rstest; #[rstest] - #[case(&[], SERIAL_TYPE_NULL, OwnedValue::Null)] - #[case(&[255], SERIAL_TYPE_INT8, OwnedValue::Integer(-1))] - #[case(&[0x12, 0x34], SERIAL_TYPE_BEINT16, OwnedValue::Integer(0x1234))] - #[case(&[0xFE], SERIAL_TYPE_INT8, OwnedValue::Integer(-2))] - #[case(&[0x12, 0x34, 0x56], SERIAL_TYPE_BEINT24, OwnedValue::Integer(0x123456))] - #[case(&[0x12, 0x34, 0x56, 0x78], SERIAL_TYPE_BEINT32, OwnedValue::Integer(0x12345678))] - #[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC], SERIAL_TYPE_BEINT48, OwnedValue::Integer(0x123456789ABC))] - #[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF], SERIAL_TYPE_BEINT64, OwnedValue::Integer(0x123456789ABCDEFF))] - #[case(&[0x40, 0x09, 0x21, 0xFB, 0x54, 0x44, 0x2D, 0x18], SERIAL_TYPE_BEFLOAT64, OwnedValue::Float(std::f64::consts::PI))] - #[case(&[1, 2], SERIAL_TYPE_CONSTINT0, OwnedValue::Integer(0))] - #[case(&[65, 66], SERIAL_TYPE_CONSTINT1, OwnedValue::Integer(1))] - #[case(&[1, 2, 3], 18, OwnedValue::Blob(vec![1, 2, 3].into()))] - #[case(&[], 12, OwnedValue::Blob(vec![].into()))] // empty blob - #[case(&[65, 66, 67], 19, OwnedValue::build_text("ABC"))] - #[case(&[0x80], SERIAL_TYPE_INT8, OwnedValue::Integer(-128))] - #[case(&[0x80, 0], SERIAL_TYPE_BEINT16, OwnedValue::Integer(-32768))] - #[case(&[0x80, 0, 0], SERIAL_TYPE_BEINT24, OwnedValue::Integer(-8388608))] - #[case(&[0x80, 0, 0, 0], SERIAL_TYPE_BEINT32, OwnedValue::Integer(-2147483648))] - #[case(&[0x80, 0, 0, 0, 0, 0], SERIAL_TYPE_BEINT48, OwnedValue::Integer(-140737488355328))] - #[case(&[0x80, 0, 0, 0, 0, 0, 0, 0], SERIAL_TYPE_BEINT64, OwnedValue::Integer(-9223372036854775808))] - #[case(&[0x7f], SERIAL_TYPE_INT8, OwnedValue::Integer(127))] - #[case(&[0x7f, 0xff], SERIAL_TYPE_BEINT16, OwnedValue::Integer(32767))] - #[case(&[0x7f, 0xff, 0xff], SERIAL_TYPE_BEINT24, OwnedValue::Integer(8388607))] - #[case(&[0x7f, 0xff, 0xff, 0xff], SERIAL_TYPE_BEINT32, OwnedValue::Integer(2147483647))] - #[case(&[0x7f, 0xff, 0xff, 0xff, 0xff, 0xff], SERIAL_TYPE_BEINT48, OwnedValue::Integer(140737488355327))] - #[case(&[0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], SERIAL_TYPE_BEINT64, OwnedValue::Integer(9223372036854775807))] + #[case(&[], SerialType::Null, OwnedValue::Null)] + #[case(&[255], SerialType::I8, OwnedValue::Integer(-1))] + #[case(&[0x12, 0x34], SerialType::I16, OwnedValue::Integer(0x1234))] + #[case(&[0xFE], SerialType::I8, OwnedValue::Integer(-2))] + #[case(&[0x12, 0x34, 0x56], SerialType::I24, OwnedValue::Integer(0x123456))] + #[case(&[0x12, 0x34, 0x56, 0x78], SerialType::I32, OwnedValue::Integer(0x12345678))] + #[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC], SerialType::I48, OwnedValue::Integer(0x123456789ABC))] + #[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF], SerialType::I64, OwnedValue::Integer(0x123456789ABCDEFF))] + #[case(&[0x40, 0x09, 0x21, 0xFB, 0x54, 0x44, 0x2D, 0x18], SerialType::F64, OwnedValue::Float(std::f64::consts::PI))] + #[case(&[1, 2], SerialType::ConstInt0, OwnedValue::Integer(0))] + #[case(&[65, 66], SerialType::ConstInt1, OwnedValue::Integer(1))] + #[case(&[1, 2, 3], SerialType::Blob { content_size: 3 }, OwnedValue::Blob(vec![1, 2, 3].into()))] + #[case(&[], SerialType::Blob { content_size: 0 }, OwnedValue::Blob(vec![].into()))] // empty blob + #[case(&[65, 66, 67], SerialType::Text { content_size: 3 }, OwnedValue::build_text("ABC"))] + #[case(&[0x80], SerialType::I8, OwnedValue::Integer(-128))] + #[case(&[0x80, 0], SerialType::I16, OwnedValue::Integer(-32768))] + #[case(&[0x80, 0, 0], SerialType::I24, OwnedValue::Integer(-8388608))] + #[case(&[0x80, 0, 0, 0], SerialType::I32, OwnedValue::Integer(-2147483648))] + #[case(&[0x80, 0, 0, 0, 0, 0], SerialType::I48, OwnedValue::Integer(-140737488355328))] + #[case(&[0x80, 0, 0, 0, 0, 0, 0, 0], SerialType::I64, OwnedValue::Integer(-9223372036854775808))] + #[case(&[0x7f], SerialType::I8, OwnedValue::Integer(127))] + #[case(&[0x7f, 0xff], SerialType::I16, OwnedValue::Integer(32767))] + #[case(&[0x7f, 0xff, 0xff], SerialType::I24, OwnedValue::Integer(8388607))] + #[case(&[0x7f, 0xff, 0xff, 0xff], SerialType::I32, OwnedValue::Integer(2147483647))] + #[case(&[0x7f, 0xff, 0xff, 0xff, 0xff, 0xff], SerialType::I48, OwnedValue::Integer(140737488355327))] + #[case(&[0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], SerialType::I64, OwnedValue::Integer(9223372036854775807))] fn test_read_value( #[case] buf: &[u8], #[case] serial_type: SerialType, @@ -1713,46 +1600,47 @@ mod tests { #[test] fn test_serial_type_helpers() { - assert!(SERIAL_TYPE_NULL.is_null()); - assert!(SERIAL_TYPE_INT8.is_int8()); - assert!(SERIAL_TYPE_BEINT16.is_beint16()); - assert!(SERIAL_TYPE_BEINT24.is_beint24()); - assert!(SERIAL_TYPE_BEINT32.is_beint32()); - assert!(SERIAL_TYPE_BEINT48.is_beint48()); - assert!(SERIAL_TYPE_BEINT64.is_beint64()); - assert!(SERIAL_TYPE_BEFLOAT64.is_befloat64()); - assert!(SERIAL_TYPE_CONSTINT0.is_constint0()); - assert!(SERIAL_TYPE_CONSTINT1.is_constint1()); - - assert!(12u64.is_blob()); - assert!(14u64.is_blob()); - assert!(13u64.is_string()); - assert!(15u64.is_string()); - - assert_eq!(12u64.blob_size(), 0); - assert_eq!(14u64.blob_size(), 1); - assert_eq!(16u64.blob_size(), 2); - - assert_eq!(13u64.string_size(), 0); - assert_eq!(15u64.string_size(), 1); - assert_eq!(17u64.string_size(), 2); + assert_eq!( + TryInto::::try_into(12u64).unwrap(), + SerialType::Blob { content_size: 0 } + ); + assert_eq!( + TryInto::::try_into(14u64).unwrap(), + SerialType::Blob { content_size: 1 } + ); + assert_eq!( + TryInto::::try_into(13u64).unwrap(), + SerialType::Text { content_size: 0 } + ); + assert_eq!( + TryInto::::try_into(15u64).unwrap(), + SerialType::Text { content_size: 1 } + ); + assert_eq!( + TryInto::::try_into(16u64).unwrap(), + SerialType::Blob { content_size: 2 } + ); + assert_eq!( + TryInto::::try_into(17u64).unwrap(), + SerialType::Text { content_size: 2 } + ); } #[rstest] - #[case(0, SERIAL_TYPE_NULL)] - #[case(1, SERIAL_TYPE_INT8)] - #[case(2, SERIAL_TYPE_BEINT16)] - #[case(3, SERIAL_TYPE_BEINT24)] - #[case(4, SERIAL_TYPE_BEINT32)] - #[case(5, SERIAL_TYPE_BEINT48)] - #[case(6, SERIAL_TYPE_BEINT64)] - #[case(7, SERIAL_TYPE_BEFLOAT64)] - #[case(8, SERIAL_TYPE_CONSTINT0)] - #[case(9, SERIAL_TYPE_CONSTINT1)] - #[case(12, 12)] // Blob(0) - #[case(13, 13)] // String(0) - #[case(14, 14)] // Blob(1) - #[case(15, 15)] // String(1) + #[case(0, SerialType::Null)] + #[case(1, SerialType::I8)] + #[case(2, SerialType::I16)] + #[case(3, SerialType::I24)] + #[case(4, SerialType::I32)] + #[case(5, SerialType::I48)] + #[case(6, SerialType::I64)] + #[case(7, SerialType::F64)] + #[case(8, SerialType::ConstInt0)] + #[case(9, SerialType::ConstInt1)] + #[case(12, SerialType::Blob { content_size: 0 })] + #[case(13, SerialType::Text { content_size: 0 })] + #[case(14, SerialType::Blob { content_size: 1 })] + #[case(15, SerialType::Text { content_size: 1 })] fn test_validate_serial_type(#[case] input: u64, #[case] expected: SerialType) { let result = validate_serial_type(input).unwrap(); assert_eq!(result, expected); diff --git a/core/types.rs b/core/types.rs index 045f13393..01b902d9f 100644 --- a/core/types.rs +++ b/core/types.rs @@ -787,7 +787,7 @@ impl ImmutableRecord { serials.push((serial_type_buf, n)); let value_size = match serial_type { - SerialType::Null => 0, + SerialType::Null | SerialType::ConstInt0 | SerialType::ConstInt1 => 0, SerialType::I8 => 1, SerialType::I16 => 2, SerialType::I24 => 3, @@ -845,6 +845,7 @@ impl ImmutableRecord { values.push(RefValue::Integer(*i)); let serial_type = SerialType::from(value); match serial_type { + SerialType::ConstInt0 | SerialType::ConstInt1 => {} SerialType::I8 => writer.extend_from_slice(&(*i as i8).to_be_bytes()), SerialType::I16 => writer.extend_from_slice(&(*i as i16).to_be_bytes()), SerialType::I24 => { @@ -853,7 +854,7 @@ impl ImmutableRecord { SerialType::I32 => writer.extend_from_slice(&(*i as i32).to_be_bytes()), SerialType::I48 => writer.extend_from_slice(&i.to_be_bytes()[2..]), // remove 2 most significant bytes SerialType::I64 => writer.extend_from_slice(&i.to_be_bytes()), - _ => unreachable!(), + other => panic!("Serial type is not an integer: {:?}", other), } } OwnedValue::Float(f) => { @@ -1113,7 +1114,7 @@ const I48_HIGH: i64 = 140737488355327; /// Sqlite Serial Types /// https://www.sqlite.org/fileformat.html#record_format #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -enum SerialType { +pub enum SerialType { Null, I8, I16, @@ -1122,6 +1123,8 @@ enum SerialType { I48, I64, F64, + ConstInt0, + ConstInt1, Text { content_size: usize }, Blob { content_size: usize }, } @@ -1131,6 +1134,8 @@ impl From<&OwnedValue> for SerialType { match value { OwnedValue::Null => SerialType::Null, OwnedValue::Integer(i) => match i { + 0 => SerialType::ConstInt0, + 1 => SerialType::ConstInt1, i if *i >= I8_LOW && *i <= I8_HIGH => SerialType::I8, i if *i >= I16_LOW && *i <= I16_HIGH => SerialType::I16, i if *i >= I24_LOW && *i <= I24_HIGH => SerialType::I24, @@ -1160,12 +1165,46 @@ impl From for u64 { SerialType::I48 => 5, SerialType::I64 => 6, SerialType::F64 => 7, + SerialType::ConstInt0 => 8, + SerialType::ConstInt1 => 9, SerialType::Text { content_size } => (content_size * 2 + 13) as u64, SerialType::Blob { content_size } => (content_size * 2 + 12) as u64, } } } +impl TryFrom for SerialType { + type Error = LimboError; + + fn try_from(serial_type: u64) -> Result { + match serial_type { + 0 => Ok(SerialType::Null), + 1 => Ok(SerialType::I8), + 2 => Ok(SerialType::I16), + 3 => Ok(SerialType::I24), + 4 => Ok(SerialType::I32), + 5 => Ok(SerialType::I48), + 6 => Ok(SerialType::I64), + 7 => Ok(SerialType::F64), + 8 => Ok(SerialType::ConstInt0), + 9 => Ok(SerialType::ConstInt1), + n if n >= 12 => match n % 2 { + 0 => Ok(SerialType::Blob { + content_size: (n as usize - 12) / 2, + }), + 1 => Ok(SerialType::Text { + content_size: (n as usize - 13) / 2, + }), + _ => unreachable!(), + }, + _ => Err(LimboError::Corrupt(format!( + "Invalid serial type: {}", + serial_type + ))), + } + } +} + impl Record { pub fn new(values: Vec) -> Self { Self { values } From 04adf8242a4a2a8de4414ef30e611cf5eb52d01c Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Thu, 24 Apr 2025 16:05:12 +0300 Subject: [PATCH 2/2] faster validate --- core/storage/sqlite3_ondisk.rs | 30 ++++++++++++++++++++++-------- core/types.rs | 7 +++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 2e86bbdff..833c928fe 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -985,8 +985,12 @@ fn read_payload(unread: &'static [u8], payload_size: usize) -> (&'static [u8], O } } -pub fn validate_serial_type(value: u64) -> Result { - value.try_into() +#[inline(always)] +pub fn validate_serial_type(value: u64) -> Result<()> { + if !SerialType::u64_is_valid_serial_type(value) { + crate::bail_corrupt_error!("Invalid serial type: {}", value); + } + Ok(()) } pub struct SmallVec { @@ -1081,7 +1085,7 @@ pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Res let mut serial_types = SmallVec::::new(); while header_size > 0 { let (serial_type, nr) = read_varint(&reuse_immutable.get_payload()[pos..])?; - let _ = validate_serial_type(serial_type)?; + validate_serial_type(serial_type)?; serial_types.push(serial_type); pos += nr; assert!(header_size >= nr); @@ -1641,15 +1645,25 @@ mod tests { #[case(13, SerialType::Text { content_size: 0 })] #[case(14, SerialType::Blob { content_size: 1 })] #[case(15, SerialType::Text { content_size: 1 })] - fn test_validate_serial_type(#[case] input: u64, #[case] expected: SerialType) { - let result = validate_serial_type(input).unwrap(); + fn test_parse_serial_type(#[case] input: u64, #[case] expected: SerialType) { + let result = SerialType::try_from(input).unwrap(); assert_eq!(result, expected); } #[test] - fn test_invalid_serial_type() { - let result = validate_serial_type(10); - assert!(result.is_err()); + fn test_validate_serial_type() { + for i in 0..=9 { + let result = validate_serial_type(i); + assert!(result.is_ok()); + } + for i in 10..=11 { + let result = validate_serial_type(i); + assert!(result.is_err()); + } + for i in 12..=1000 { + let result = validate_serial_type(i); + assert!(result.is_ok()); + } } #[test] diff --git a/core/types.rs b/core/types.rs index 01b902d9f..4173324f8 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1129,6 +1129,13 @@ pub enum SerialType { Blob { content_size: usize }, } +impl SerialType { + #[inline(always)] + pub fn u64_is_valid_serial_type(n: u64) -> bool { + n != 10 && n != 11 + } +} + impl From<&OwnedValue> for SerialType { fn from(value: &OwnedValue) -> Self { match value {