Normalize Windows paths in SQLite URIs, per the SQLite spec

This commit is contained in:
PThorpe92
2025-02-18 21:29:26 -05:00
parent 42a0c18574
commit e86f00cb81

View File

@@ -382,18 +382,24 @@ pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result<Vec<
#[derive(Debug, Default, PartialEq)]
pub struct OpenOptions<'a> {
/// The authority component of the URI. may be 'localhost' or empty
pub authority: Option<&'a str>,
/// The normalized path to the database file
pub path: String,
pub fragment: Option<String>,
/// The vfs query parameter causes the database connection to be opened using the VFS called NAME
pub vfs: Option<String>,
pub mode: Mode,
/// read-only, read-write, read-write and created if it does not exist, or pure in-memory database that never interacts with disk
pub mode: OpenMode,
/// Attempt to set the permissions of the new database file to match the existing file "filename".
pub modeof: Option<String>,
pub cache: Option<CacheMode>,
pub immutable: Option<bool>,
/// Specifies Cache mode shared | private
pub cache: CacheMode,
/// immutable=1|0 specifies that the database is stored on read-only media
pub immutable: bool,
}
#[derive(Clone, Default, Debug, Copy, PartialEq)]
pub enum Mode {
pub enum OpenMode {
ReadOnly,
ReadWrite,
Memory,
@@ -401,8 +407,9 @@ pub enum Mode {
ReadWriteCreate,
}
/// Cache mode selected by the `cache` URI query parameter.
///
/// `Private` is the default, used when the URI does not specify a cache mode.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Private cache (the default).
    #[default]
    Private,
    /// Shared cache.
    Shared,
}
@@ -417,13 +424,13 @@ impl From<&str> for CacheMode {
}
}
impl Mode {
impl OpenMode {
pub fn from_str(s: &str) -> Result<Self> {
match s.trim().to_lowercase().as_str() {
"ro" => Ok(Mode::ReadOnly),
"rw" => Ok(Mode::ReadWrite),
"memory" => Ok(Mode::Memory),
"rwc" => Ok(Mode::ReadWriteCreate),
"ro" => Ok(OpenMode::ReadOnly),
"rw" => Ok(OpenMode::ReadWrite),
"memory" => Ok(OpenMode::Memory),
"rwc" => Ok(OpenMode::ReadWriteCreate),
_ => Err(LimboError::InvalidArgument(format!(
"Invalid mode: '{}'. Expected one of 'ro', 'rw', 'memory', 'rwc'",
s
@@ -432,14 +439,36 @@ impl Mode {
}
pub fn get_flags(&self) -> OpenFlags {
match self {
Mode::ReadWriteCreate => OpenFlags::Create,
OpenMode::ReadWriteCreate => OpenFlags::Create,
_ => OpenFlags::None,
}
}
}
/// Reports whether `path` starts with a Windows drive prefix such as `C:/` or `C:\`.
///
/// The check is purely positional: the second character must be `:` and the
/// third must be `/` or `\`. The first character is not inspected.
fn is_windows_path(path: &str) -> bool {
    // Walk the characters once instead of rescanning with `chars().nth()`;
    // a string shorter than three characters simply runs out of items.
    let mut rest = path.chars().skip(1);
    rest.next() == Some(':') && matches!(rest.next(), Some('/') | Some('\\'))
}
/// Rewrites a Windows-style path into the forward-slash form used by SQLite URIs.
///
/// Backslashes become `/`, every run of consecutive slashes collapses to a
/// single `/`, and a path that begins with a drive prefix (`C:/...`) gains a
/// leading `/` so it reads `/C:/...`.
fn normalize_windows_path(path: &str) -> String {
    // One extra byte of capacity for the leading slash that may be inserted.
    let mut normalized = String::with_capacity(path.len() + 1);
    let mut previous_was_slash = false;
    for raw in path.chars() {
        let ch = if raw == '\\' { '/' } else { raw };
        let is_slash = ch == '/';
        // Drop the second and later slashes of a run.
        if !(is_slash && previous_was_slash) {
            normalized.push(ch);
        }
        previous_was_slash = is_slash;
    }

    // An absolute drive path looks like `X:/...` and does not already start with `/`.
    let has_bare_drive_prefix = {
        let mut head = normalized.chars();
        head.next() != Some('/') && head.next() == Some(':') && head.next() == Some('/')
    };
    if has_bare_drive_prefix {
        normalized.insert(0, '/');
    }
    normalized
}
/// Parses a SQLite URI, handling Windows and Unix paths separately.
@@ -454,23 +483,20 @@ pub fn parse_sqlite_uri(uri: &str) -> Result<OpenOptions> {
let mut opts = OpenOptions::default();
let without_scheme = &uri[5..];
let (without_fragment, fragment) = without_scheme
let (without_fragment, _) = without_scheme
.split_once('#')
.unwrap_or((without_scheme, ""));
if !fragment.is_empty() {
opts.fragment = Some(decode_percent(fragment));
}
let (without_query, query) = without_fragment
.split_once('?')
.unwrap_or((without_fragment, ""));
parse_query_params(query, &mut opts)?;
// Handle authority + path separately
// handle authority + path separately
if let Some(after_slashes) = without_query.strip_prefix("//") {
let (authority, path) = after_slashes.split_once('/').unwrap_or((after_slashes, ""));
// SQLite allows only `localhost` or empty authority.
// sqlite allows only `localhost` or empty authority.
if !(authority.is_empty() || authority == "localhost") {
return Err(LimboError::InvalidArgument(format!(
"Invalid authority '{}'. Only '' or 'localhost' allowed.",
@@ -484,7 +510,7 @@ pub fn parse_sqlite_uri(uri: &str) -> Result<OpenOptions> {
};
if is_windows_path(path) {
opts.path = format!("/{}", decode_percent(path)); // Ensure `/C:/` format
opts.path = normalize_windows_path(&decode_percent(path));
} else if !path.is_empty() {
opts.path = format!("/{}", decode_percent(path));
} else {
@@ -504,10 +530,10 @@ fn parse_query_params(query: &str, opts: &mut OpenOptions) -> Result<()> {
if let Some((key, value)) = param.split_once('=') {
let decoded_value = decode_percent(value);
match key {
"mode" => opts.mode = Mode::from_str(value)?,
"mode" => opts.mode = OpenMode::from_str(value)?,
"modeof" => opts.modeof = Some(decoded_value),
"cache" => opts.cache = Some(decoded_value.as_str().into()),
"immutable" => opts.immutable = Some(decoded_value == "1"),
"cache" => opts.cache = decoded_value.as_str().into(),
"immutable" => opts.immutable = decoded_value == "1",
"vfs" => opts.vfs = Some(decoded_value),
_ => {}
}
@@ -516,29 +542,64 @@ fn parse_query_params(query: &str, opts: &mut OpenOptions) -> Result<()> {
Ok(())
}
/// Decodes percent-encoded characters (e.g., `%20` → `' '`).
///
/// Each `%XX` sequence, where `XX` is two hexadecimal digits in either case,
/// is replaced by the byte it encodes. A `%` that is not followed by two hex
/// digits is copied through unchanged. Decoding works on bytes, so a
/// multi-byte UTF-8 character written as several escapes (`%C3%A9`) is
/// reassembled; byte sequences that are not valid UTF-8 are replaced with
/// U+FFFD by `String::from_utf8_lossy`.
///
/// this function was adapted from the 'urlencoding' crate. MIT
pub fn decode_percent(uri: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` if `digit` is not one.
    fn hex_value(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            _ => None,
        }
    }

    // Fast path: nothing to decode. `find` yields a *byte* offset, so it is
    // safe to slice with even when non-ASCII text precedes the first `%`.
    let first_escape = match uri.find('%') {
        Some(index) => index,
        None => return uri.to_string(),
    };

    let bytes = uri.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(bytes.len());
    decoded.extend_from_slice(&bytes[..first_escape]);

    let mut pos = first_escape;
    while pos < bytes.len() {
        if bytes[pos] == b'%' {
            let high = bytes.get(pos + 1).copied().and_then(hex_value);
            let low = bytes.get(pos + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                pos += 3;
                continue;
            }
            // Malformed or truncated escape: fall through and keep the `%`
            // literally; the bytes after it are re-examined on their own.
        }
        decoded.push(bytes[pos]);
        pos += 1;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
#[cfg(test)]
@@ -826,8 +887,8 @@ pub mod tests {
let opts = parse_sqlite_uri(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.vfs, Some("unix".to_string()));
assert_eq!(opts.mode, Mode::ReadOnly);
assert_eq!(opts.immutable, Some(true));
assert_eq!(opts.mode, OpenMode::ReadOnly);
assert_eq!(opts.immutable, true);
}
#[test]
@@ -835,7 +896,6 @@ pub mod tests {
let uri = "file:/home/user/db.sqlite#section1";
let opts = parse_sqlite_uri(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.fragment, Some("section1".to_string()));
}
#[test]
@@ -867,7 +927,7 @@ pub mod tests {
let uri = "file:/home/user/db.sqlite?mode=rw";
let opts = parse_sqlite_uri(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.mode, Mode::ReadWrite);
assert_eq!(opts.mode, OpenMode::ReadWrite);
assert_eq!(opts.vfs, None);
}
@@ -883,8 +943,8 @@ pub mod tests {
let uri = "file:?mode=memory&cache=shared";
let opts = parse_sqlite_uri(uri).unwrap();
assert_eq!(opts.path, "");
assert_eq!(opts.mode, Mode::Memory);
assert_eq!(opts.cache, Some(CacheMode::Shared));
assert_eq!(opts.mode, OpenMode::Memory);
assert_eq!(opts.cache, CacheMode::Shared);
}
#[test]
@@ -892,7 +952,6 @@ pub mod tests {
let uri = "file:#fragment";
let opts = parse_sqlite_uri(uri).unwrap();
assert_eq!(opts.path, "");
assert_eq!(opts.fragment, Some("fragment".to_string()));
}
#[test]
@@ -909,9 +968,9 @@ pub mod tests {
let opts = parse_sqlite_uri(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.vfs, Some("unix".to_string()));
assert_eq!(opts.mode, Mode::ReadWrite);
assert_eq!(opts.cache, Some(CacheMode::Private));
assert_eq!(opts.immutable, Some(false));
assert_eq!(opts.mode, OpenMode::ReadWrite);
assert_eq!(opts.cache, CacheMode::Private);
assert_eq!(opts.immutable, false);
}
#[test]
@@ -995,7 +1054,7 @@ pub mod tests {
assert_eq!(opts.path, "data.db");
assert_eq!(opts.authority, None);
assert_eq!(opts.vfs, None);
assert_eq!(opts.mode, Mode::ReadWriteCreate);
assert_eq!(opts.mode, OpenMode::ReadWriteCreate);
}
#[test]
@@ -1005,7 +1064,7 @@ pub mod tests {
assert_eq!(opts.path, "/home/data/data.db");
assert_eq!(opts.authority, None);
assert_eq!(opts.vfs, None);
assert_eq!(opts.mode, Mode::ReadWriteCreate);
assert_eq!(opts.mode, OpenMode::ReadWriteCreate);
}
#[test]
@@ -1031,4 +1090,66 @@ pub mod tests {
assert_eq!(opts.path, "/C:/Documents and Settings/fred/Desktop/data.db");
assert_eq!(opts.vfs, None);
}
#[test]
fn test_decode_percent_basic() {
    // (encoded input, expected decoding)
    let cases = [
        ("hello%20world", "hello world"),
        ("file%3Adata.db", "file:data.db"),
        ("path%2Fto%2Ffile", "path/to/file"),
    ];
    for &(encoded, expected) in cases.iter() {
        assert_eq!(decode_percent(encoded), expected);
    }
}
#[test]
fn test_decode_percent_edge_cases() {
    // (encoded input, expected decoding)
    let cases = [
        // empty and escape-free inputs pass through untouched
        ("", ""),
        ("plain_text", "plain_text"),
        ("%2Fhome%2Fuser%2Fdb.sqlite", "/home/user/db.sqlite"),
        // multiple percent-encoded characters in sequence
        ("%41%42%43", "ABC"),
        ("%61%62%63", "abc"),
    ];
    for &(encoded, expected) in cases.iter() {
        assert_eq!(decode_percent(encoded), expected);
    }
}
#[test]
fn test_decode_percent_invalid_sequences() {
    // (encoded input, expected decoding)
    let cases = [
        // a lone `%` with no hex digits after it
        ("hello%", "hello%"),
        // only one hex digit after `%`
        ("file%2", "file%2"),
        // second digit is not hexadecimal (0-9, A-F, a-f)
        ("file%2X.db", "file%2X.db"),
        // incomplete sequence at the end is left untouched
        ("path%2Fto%2", "path/to%2"),
    ];
    for &(encoded, expected) in cases.iter() {
        assert_eq!(decode_percent(encoded), expected);
    }
}
#[test]
fn test_decode_percent_mixed_valid_invalid() {
    // Valid escapes are decoded while malformed ones in the same input are kept.
    let cases = [
        ("hello%20world%", "hello world%"),
        ("%2Fpath%2Xto%2Ffile", "/path%2Xto/file"),
        ("file%3Adata.db%2", "file:data.db%2"),
    ];
    for &(encoded, expected) in cases.iter() {
        assert_eq!(decode_percent(encoded), expected);
    }
}
#[test]
fn test_decode_percent_special_characters() {
    // Punctuation, including an encoded `%` (`%25`) and backslash (`%5C`).
    let cases = [
        ("%21%40%23%24%25%5E%26%2A%28%29", "!@#$%^&*()"),
        ("%5B%5D%7B%7D%7C%5C%3A", "[]{}|\\:"),
    ];
    for &(encoded, expected) in cases.iter() {
        assert_eq!(decode_percent(encoded), expected);
    }
}
#[test]
fn test_decode_percent_unmodified_valid_text() {
    // Inputs without any escape must come back unchanged.
    let untouched = [
        "C:/Users/Example/Database.sqlite",
        "/home/user/db.sqlite",
    ];
    for &text in untouched.iter() {
        assert_eq!(decode_percent(text), text);
    }
}
}