From 04adf8242a4a2a8de4414ef30e611cf5eb52d01c Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Thu, 24 Apr 2025 16:05:12 +0300 Subject: [PATCH] faster validate --- core/storage/sqlite3_ondisk.rs | 30 ++++++++++++++++++++++-------- core/types.rs | 7 +++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 2e86bbdff..833c928fe 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -985,8 +985,12 @@ fn read_payload(unread: &'static [u8], payload_size: usize) -> (&'static [u8], O } } -pub fn validate_serial_type(value: u64) -> Result { - value.try_into() +#[inline(always)] +pub fn validate_serial_type(value: u64) -> Result<()> { + if !SerialType::u64_is_valid_serial_type(value) { + crate::bail_corrupt_error!("Invalid serial type: {}", value); + } + Ok(()) } pub struct SmallVec { @@ -1081,7 +1085,7 @@ pub fn read_record(payload: &[u8], reuse_immutable: &mut ImmutableRecord) -> Res let mut serial_types = SmallVec::::new(); while header_size > 0 { let (serial_type, nr) = read_varint(&reuse_immutable.get_payload()[pos..])?; - let _ = validate_serial_type(serial_type)?; + validate_serial_type(serial_type)?; serial_types.push(serial_type); pos += nr; assert!(header_size >= nr); @@ -1641,15 +1645,25 @@ mod tests { #[case(13, SerialType::Text { content_size: 0 })] #[case(14, SerialType::Blob { content_size: 1 })] #[case(15, SerialType::Text { content_size: 1 })] - fn test_validate_serial_type(#[case] input: u64, #[case] expected: SerialType) { - let result = validate_serial_type(input).unwrap(); + fn test_parse_serial_type(#[case] input: u64, #[case] expected: SerialType) { + let result = SerialType::try_from(input).unwrap(); assert_eq!(result, expected); } #[test] - fn test_invalid_serial_type() { - let result = validate_serial_type(10); - assert!(result.is_err()); + fn test_validate_serial_type() { + for i in 0..=9 { + let result = validate_serial_type(i); + assert!(result.is_ok()); + } + for i in 10..=11 { + let result = validate_serial_type(i); + assert!(result.is_err()); + } + for i in 12..=1000 { + let result = validate_serial_type(i); + assert!(result.is_ok()); + } } #[test] diff --git a/core/types.rs b/core/types.rs index 01b902d9f..4173324f8 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1129,6 +1129,13 @@ pub enum SerialType { Blob { content_size: usize }, } +impl SerialType { + #[inline(always)] + pub fn u64_is_valid_serial_type(n: u64) -> bool { + n != 10 && n != 11 + } +} + impl From<&OwnedValue> for SerialType { fn from(value: &OwnedValue) -> Self { match value {