From 932536a03f99c519bb6eecbd992647fd542304e1 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Tue, 15 Jul 2025 17:57:52 +0300 Subject: [PATCH] compare_records: fix assumption that header size is 1 byte and serial type is 1 byte --- core/types.rs | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/core/types.rs b/core/types.rs index ac9c00495..963c62906 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1618,7 +1618,6 @@ pub fn get_tie_breaker_from_seek_op(seek_op: SeekOp) -> std::cmp::Ordering { /// /// The function uses the optimized path when ALL of these conditions are met: /// - Payload is at least 2 bytes (header size + first serial type) -/// - Header size ≤ 63 bytes (`payload[0] <= 63`) - safety constraint /// - First serial type indicates integer (`1-6`, `8`, or `9`) /// - First unpacked field is a `RefValue::Integer` /// @@ -1650,7 +1649,7 @@ fn compare_records_int( tie_breaker: std::cmp::Ordering, ) -> Result { let payload = serialized.get_payload(); - if payload.len() < 2 || payload[0] > 63 { + if payload.len() < 2 { return compare_records_generic( serialized, unpacked, @@ -1661,10 +1660,21 @@ fn compare_records_int( ); } - let header_size = payload[0] as usize; - let first_serial_type = payload[1]; + let (header_size, offset_1st_serialtype) = read_varint(payload)?; + let header_size = header_size as usize; - if !matches!(first_serial_type, 1..=6 | 8 | 9) { + if payload.len() < header_size { + return Err(LimboError::Corrupt(format!( + "Record payload too short: claimed header size {} but payload only {} bytes", + header_size, + payload.len() + ))); + } + + let (first_serial_type, _) = read_varint(&payload[offset_1st_serialtype..])?; + + let serialtype_is_integer = matches!(first_serial_type, 1..=6 | 8 | 9); + if !serialtype_is_integer { return compare_records_generic( serialized, unpacked, @@ -1677,7 +1687,7 @@ fn compare_records_int( let data_start = header_size; - let lhs_int = read_integer(&payload[data_start..], first_serial_type)?; + let lhs_int = read_integer(&payload[data_start..], first_serial_type as u8)?; let RefValue::Integer(rhs_int) = unpacked[0] else { return compare_records_generic( serialized, @@ -1768,11 +1778,21 @@ fn compare_records_string( ); } - let header_size = payload[0] as usize; - let first_serial_type = payload[1]; + let (header_size, offset_1st_serialtype) = read_varint(payload)?; + let header_size = header_size as usize; - // Check if serial type is not a string or if its a blob - if first_serial_type < 13 || (first_serial_type & 1) == 0 { + if payload.len() < header_size { + return Err(LimboError::Corrupt(format!( + "Record payload too short: claimed header size {} but payload only {} bytes", + header_size, + payload.len() + ))); + } + + let (first_serial_type, _) = read_varint(&payload[offset_1st_serialtype..])?; + + let serialtype_is_string = first_serial_type >= 13 && (first_serial_type & 1) == 1; + if !serialtype_is_string { return compare_records_generic( serialized, unpacked, @@ -1799,7 +1819,7 @@ fn compare_records_string( debug_assert!(data_start + string_len <= payload.len()); - let serial_type = SerialType::try_from(first_serial_type as u64)?; + let serial_type = SerialType::try_from(first_serial_type)?; let (lhs_value, _) = read_value(&payload[data_start..], serial_type)?; let RefValue::Text(lhs_text) = lhs_value else {