diff --git a/core/util.rs b/core/util.rs index b50609ebd..c4a6c3516 100644 --- a/core/util.rs +++ b/core/util.rs @@ -382,18 +382,24 @@ pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result { + /// The authority component of the URI. may be 'localhost' or empty pub authority: Option<&'a str>, + /// The normalized path to the database file pub path: String, - pub fragment: Option, + /// The vfs query parameter causes the database connection to be opened using the VFS called NAME pub vfs: Option, - pub mode: Mode, + /// read-only, read-write, read-write and created if it does not exist, or pure in-memory database that never interacts with disk + pub mode: OpenMode, + /// Attempt to set the permissions of the new database file to match the existing file "filename". pub modeof: Option, - pub cache: Option, - pub immutable: Option, + /// Specifies Cache mode shared | private + pub cache: CacheMode, + /// immutable=1|0 specifies that the database is stored on read-only media + pub immutable: bool, } #[derive(Clone, Default, Debug, Copy, PartialEq)] -pub enum Mode { +pub enum OpenMode { ReadOnly, ReadWrite, Memory, @@ -401,8 +407,9 @@ pub enum Mode { ReadWriteCreate, } -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Default, Clone, Copy, PartialEq)] pub enum CacheMode { + #[default] Private, Shared, } @@ -417,13 +424,13 @@ impl From<&str> for CacheMode { } } -impl Mode { +impl OpenMode { pub fn from_str(s: &str) -> Result { match s.trim().to_lowercase().as_str() { - "ro" => Ok(Mode::ReadOnly), - "rw" => Ok(Mode::ReadWrite), - "memory" => Ok(Mode::Memory), - "rwc" => Ok(Mode::ReadWriteCreate), + "ro" => Ok(OpenMode::ReadOnly), + "rw" => Ok(OpenMode::ReadWrite), + "memory" => Ok(OpenMode::Memory), + "rwc" => Ok(OpenMode::ReadWriteCreate), _ => Err(LimboError::InvalidArgument(format!( "Invalid mode: '{}'. Expected one of 'ro', 'rw', 'memory', 'rwc'", s @@ -432,14 +439,36 @@ impl Mode { } pub fn get_flags(&self) -> OpenFlags { match self { - Mode::ReadWriteCreate => OpenFlags::Create, + OpenMode::ReadWriteCreate => OpenFlags::Create, _ => OpenFlags::None, } } } fn is_windows_path(path: &str) -> bool { - path.len() >= 3 && path.chars().nth(1) == Some(':') && path.chars().nth(2) == Some('/') + path.len() >= 3 + && path.chars().nth(1) == Some(':') + && (path.chars().nth(2) == Some('/') || path.chars().nth(2) == Some('\\')) +} + +/// converts windows-style paths to forward slashes, per SQLite spec. +fn normalize_windows_path(path: &str) -> String { + let mut normalized = path.replace("\\", "/"); + + // remove duplicate slashes (`//` → `/`) + while normalized.contains("//") { + normalized = normalized.replace("//", "/"); + } + + // if absolute windows path (`C:/...`), ensure it starts with `/` + if normalized.len() >= 3 + && !normalized.starts_with('/') + && normalized.chars().nth(1) == Some(':') + && normalized.chars().nth(2) == Some('/') + { + normalized.insert(0, '/'); + } + normalized } /// Parses a SQLite URI, handling Windows and Unix paths separately. @@ -454,23 +483,20 @@ pub fn parse_sqlite_uri(uri: &str) -> Result { let mut opts = OpenOptions::default(); let without_scheme = &uri[5..]; - let (without_fragment, fragment) = without_scheme + let (without_fragment, _) = without_scheme .split_once('#') .unwrap_or((without_scheme, "")); - if !fragment.is_empty() { - opts.fragment = Some(decode_percent(fragment)); - } let (without_query, query) = without_fragment .split_once('?') .unwrap_or((without_fragment, "")); parse_query_params(query, &mut opts)?; - // Handle authority + path separately + // handle authority + path separately if let Some(after_slashes) = without_query.strip_prefix("//") { let (authority, path) = after_slashes.split_once('/').unwrap_or((after_slashes, "")); - // SQLite allows only `localhost` or empty authority. + // sqlite allows only `localhost` or empty authority. if !(authority.is_empty() || authority == "localhost") { return Err(LimboError::InvalidArgument(format!( "Invalid authority '{}'. Only '' or 'localhost' allowed.", @@ -484,7 +510,7 @@ pub fn parse_sqlite_uri(uri: &str) -> Result { }; if is_windows_path(path) { - opts.path = format!("/{}", decode_percent(path)); // Ensure `/C:/` format + opts.path = normalize_windows_path(&decode_percent(path)); } else if !path.is_empty() { opts.path = format!("/{}", decode_percent(path)); } else { @@ -504,10 +530,10 @@ fn parse_query_params(query: &str, opts: &mut OpenOptions) -> Result<()> { if let Some((key, value)) = param.split_once('=') { let decoded_value = decode_percent(value); match key { - "mode" => opts.mode = Mode::from_str(value)?, + "mode" => opts.mode = OpenMode::from_str(value)?, "modeof" => opts.modeof = Some(decoded_value), - "cache" => opts.cache = Some(decoded_value.as_str().into()), - "immutable" => opts.immutable = Some(decoded_value == "1"), + "cache" => opts.cache = decoded_value.as_str().into(), + "immutable" => opts.immutable = decoded_value == "1", "vfs" => opts.vfs = Some(decoded_value), _ => {} } @@ -516,29 +542,64 @@ fn parse_query_params(query: &str, opts: &mut OpenOptions) -> Result<()> { Ok(()) } -/// Decodes percent-encoded characters (e.g., `%20` → `' '`). -fn decode_percent(input: &str) -> String { - let mut result = String::new(); - let mut chars = input.chars().peekable(); +/// Decodes percent-encoded characters +/// this function was adapted from the 'urlencoding' crate. MIT +pub fn decode_percent(uri: &str) -> String { + let from_hex_digit = |digit: u8| -> Option { + match digit { + b'0'..=b'9' => Some(digit - b'0'), + b'A'..=b'F' => Some(digit - b'A' + 10), + b'a'..=b'f' => Some(digit - b'a' + 10), + _ => None, + } + }; - while let Some(c) = chars.next() { - if c == '%' { - if let (Some(h1), Some(h2)) = (chars.next(), chars.next()) { - if let Ok(byte) = u8::from_str_radix(&format!("{}{}", h1, h2), 16) { - result.push(byte as char); - } else { - result.push('%'); - result.push(h1); - result.push(h2); + let offset = uri.chars().take_while(|&c| c != '%').count(); + + if offset >= uri.len() { + return uri.to_string(); + } + + let mut decoded: Vec = Vec::with_capacity(uri.len()); + let (ascii, mut data) = uri.as_bytes().split_at(offset); + decoded.extend_from_slice(ascii); + + loop { + let mut parts = data.splitn(2, |&c| c == b'%'); + let non_escaped_part = parts.next().unwrap(); + let rest = parts.next(); + if rest.is_none() && decoded.is_empty() { + return String::from_utf8_lossy(data).to_string(); + } + decoded.extend_from_slice(non_escaped_part); + match rest { + Some(rest) => match rest.get(0..2) { + Some([first, second]) => match from_hex_digit(*first) { + Some(first_val) => match from_hex_digit(*second) { + Some(second_val) => { + decoded.push((first_val << 4) | second_val); + data = &rest[2..]; + } + None => { + decoded.extend_from_slice(&[b'%', *first]); + data = &rest[1..]; + } + }, + None => { + decoded.push(b'%'); + data = rest; + } + }, + _ => { + decoded.push(b'%'); + decoded.extend_from_slice(rest); + break; } - } else { - result.push('%'); - } - } else { - result.push(c); + }, + None => break, } } - result + String::from_utf8_lossy(&decoded).to_string() } #[cfg(test)] @@ -826,8 +887,8 @@ pub mod tests { let opts = parse_sqlite_uri(uri).unwrap(); assert_eq!(opts.path, "/home/user/db.sqlite"); assert_eq!(opts.vfs, Some("unix".to_string())); - assert_eq!(opts.mode, Mode::ReadOnly); - assert_eq!(opts.immutable, Some(true)); + assert_eq!(opts.mode, OpenMode::ReadOnly); + assert_eq!(opts.immutable, true); } #[test] @@ -835,7 +896,6 @@ pub mod tests { let uri = "file:/home/user/db.sqlite#section1"; let opts = parse_sqlite_uri(uri).unwrap(); assert_eq!(opts.path, "/home/user/db.sqlite"); - assert_eq!(opts.fragment, Some("section1".to_string())); } #[test] @@ -867,7 +927,7 @@ pub mod tests { let uri = "file:/home/user/db.sqlite?mode=rw"; let opts = parse_sqlite_uri(uri).unwrap(); assert_eq!(opts.path, "/home/user/db.sqlite"); - assert_eq!(opts.mode, Mode::ReadWrite); + assert_eq!(opts.mode, OpenMode::ReadWrite); assert_eq!(opts.vfs, None); } @@ -883,8 +943,8 @@ pub mod tests { let uri = "file:?mode=memory&cache=shared"; let opts = parse_sqlite_uri(uri).unwrap(); assert_eq!(opts.path, ""); - assert_eq!(opts.mode, Mode::Memory); - assert_eq!(opts.cache, Some(CacheMode::Shared)); + assert_eq!(opts.mode, OpenMode::Memory); + assert_eq!(opts.cache, CacheMode::Shared); } #[test] @@ -892,7 +952,6 @@ pub mod tests { let uri = "file:#fragment"; let opts = parse_sqlite_uri(uri).unwrap(); assert_eq!(opts.path, ""); - assert_eq!(opts.fragment, Some("fragment".to_string())); } #[test] @@ -909,9 +968,9 @@ pub mod tests { let opts = parse_sqlite_uri(uri).unwrap(); assert_eq!(opts.path, "/home/user/db.sqlite"); assert_eq!(opts.vfs, Some("unix".to_string())); - assert_eq!(opts.mode, Mode::ReadWrite); - assert_eq!(opts.cache, Some(CacheMode::Private)); - assert_eq!(opts.immutable, Some(false)); + assert_eq!(opts.mode, OpenMode::ReadWrite); + assert_eq!(opts.cache, CacheMode::Private); + assert_eq!(opts.immutable, false); } #[test] @@ -995,7 +1054,7 @@ pub mod tests { assert_eq!(opts.path, "data.db"); assert_eq!(opts.authority, None); assert_eq!(opts.vfs, None); - assert_eq!(opts.mode, Mode::ReadWriteCreate); + assert_eq!(opts.mode, OpenMode::ReadWriteCreate); } #[test] @@ -1005,7 +1064,7 @@ pub mod tests { assert_eq!(opts.path, "/home/data/data.db"); assert_eq!(opts.authority, None); assert_eq!(opts.vfs, None); - assert_eq!(opts.mode, Mode::ReadWriteCreate); + assert_eq!(opts.mode, OpenMode::ReadWriteCreate); } #[test] @@ -1031,4 +1090,66 @@ pub mod tests { assert_eq!(opts.path, "/C:/Documents and Settings/fred/Desktop/data.db"); assert_eq!(opts.vfs, None); } + + #[test] + fn test_decode_percent_basic() { + assert_eq!(decode_percent("hello%20world"), "hello world"); + assert_eq!(decode_percent("file%3Adata.db"), "file:data.db"); + assert_eq!(decode_percent("path%2Fto%2Ffile"), "path/to/file"); + } + + #[test] + fn test_decode_percent_edge_cases() { + assert_eq!(decode_percent(""), ""); + assert_eq!(decode_percent("plain_text"), "plain_text"); + assert_eq!( + decode_percent("%2Fhome%2Fuser%2Fdb.sqlite"), + "/home/user/db.sqlite" + ); + // multiple percent-encoded characters in sequence + assert_eq!(decode_percent("%41%42%43"), "ABC"); + assert_eq!(decode_percent("%61%62%63"), "abc"); + } + + #[test] + fn test_decode_percent_invalid_sequences() { + // invalid percent encoding (single % without two hex digits) + assert_eq!(decode_percent("hello%"), "hello%"); + // only one hex digit after % + assert_eq!(decode_percent("file%2"), "file%2"); + // invalid hex digits (not 0-9, A-F, a-f) + assert_eq!(decode_percent("file%2X.db"), "file%2X.db"); + + // Incomplete sequence at the end, leave untouched + assert_eq!(decode_percent("path%2Fto%2"), "path/to%2"); + } + + #[test] + fn test_decode_percent_mixed_valid_invalid() { + assert_eq!(decode_percent("hello%20world%"), "hello world%"); + assert_eq!(decode_percent("%2Fpath%2Xto%2Ffile"), "/path%2Xto/file"); + assert_eq!(decode_percent("file%3Adata.db%2"), "file:data.db%2"); + } + + #[test] + fn test_decode_percent_special_characters() { + assert_eq!( + decode_percent("%21%40%23%24%25%5E%26%2A%28%29"), + "!@#$%^&*()" + ); + assert_eq!(decode_percent("%5B%5D%7B%7D%7C%5C%3A"), "[]{}|\\:"); + } + + #[test] + fn test_decode_percent_unmodified_valid_text() { + // ensure already valid text remains unchanged + assert_eq!( + decode_percent("C:/Users/Example/Database.sqlite"), + "C:/Users/Example/Database.sqlite" + ); + assert_eq!( + decode_percent("/home/user/db.sqlite"), + "/home/user/db.sqlite" + ); + } }