diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 0bc8d1e67..c7ff9a3e9 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -1538,8 +1538,22 @@ pub fn read_varint(buf: &[u8]) -> Result<(u64, usize)> { } } } - v = (v << 8) + buf[8] as u64; - Ok((v, 9)) + match buf.get(8) { + Some(&c) => { + // Values requiring 9 bytes must have non-zero in the top 8 bits (value >= 1<<56). + // Since the final value is `(v<<8) + c`, the top 8 bits (v >> 48) must not be 0. + // If those are zero, this should be treated as corrupt. + // Perf? the comparison + branching happens only in parsing 9-byte varint which is rare. + if (v >> 48) == 0 { + bail_corrupt_error!("Invalid varint"); + } + v = (v << 8) + c as u64; + Ok((v, 9)) + } + None => { + bail_corrupt_error!("Invalid varint"); + } + } } pub fn varint_len(value: u64) -> usize { @@ -2208,4 +2222,14 @@ mod tests { assert_eq!(small_vec.get(8), None); } + + #[rstest] + #[case(&[])] // empty buffer + #[case(&[0x80])] // truncated 1-byte with continuation + #[case(&[0x80, 0x80])] // truncated 2-byte + #[case(&[0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80])] // 9-byte truncated to 8 + #[case(&[0x80; 9])] // bits set without end + fn test_read_varint_malformed_inputs(#[case] buf: &[u8]) { + assert!(read_varint(buf).is_err()); + } }