diff --git a/Cargo.lock b/Cargo.lock index 588cc563e..fb47c5f12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,6 +48,12 @@ dependencies = [ "equator", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anarchist-readable-name-generator-lib" version = "0.1.2" @@ -1132,6 +1138,8 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ + "allocator-api2", + "equivalent", "foldhash", ] @@ -1701,6 +1709,7 @@ dependencies = [ "limbo_completion", "limbo_crypto", "limbo_ext", + "limbo_ext_tests", "limbo_ipaddr", "limbo_macros", "limbo_percentile", @@ -1709,6 +1718,7 @@ dependencies = [ "limbo_sqlite3_parser", "limbo_time", "limbo_uuid", + "lru", "miette", "mimalloc", "parking_lot", @@ -1751,9 +1761,22 @@ dependencies = [ name = "limbo_ext" version = "0.0.16" dependencies = [ + "chrono", + "getrandom 0.3.1", "limbo_macros", ] +[[package]] +name = "limbo_ext_tests" +version = "0.0.16" +dependencies = [ + "env_logger 0.11.6", + "lazy_static", + "limbo_ext", + "log", + "mimalloc", +] + [[package]] name = "limbo_ipaddr" version = "0.0.16" @@ -1763,15 +1786,6 @@ dependencies = [ "mimalloc", ] -[[package]] -name = "limbo_kv" -version = "0.0.16" -dependencies = [ - "lazy_static", - "limbo_ext", - "mimalloc", -] - [[package]] name = "limbo_macros" version = "0.0.16" @@ -1944,6 +1958,15 @@ version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" +[[package]] +name = "lru" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "227748d55f2f0ab4735d87fd623798cb6b664512fe979705f829c9f81c934465" +dependencies = [ + "hashbrown", +] + [[package]] name = "matchers" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index e66fbb044..cb98958f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ "extensions/completion", "extensions/core", "extensions/crypto", - "extensions/kvstore", + "extensions/tests", "extensions/percentile", "extensions/regexp", "extensions/series", @@ -47,6 +47,7 @@ limbo_uuid = { path = "extensions/uuid", version = "0.0.16" } limbo_sqlite3_parser = { path = "vendored/sqlite3-parser", version = "0.0.16" } limbo_ipaddr = { path = "extensions/ipaddr", version = "0.0.16" } limbo_completion = { path = "extensions/completion", version = "0.0.16" } +limbo_ext_tests = { path = "extensions/tests", version = "0.0.16" } # Config for 'cargo dist' [workspace.metadata.dist] diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index b65624d10..03da2149a 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -6,6 +6,7 @@ pub use value::Value; pub use params::params_from_iter; use crate::params::*; +use std::num::NonZero; use std::rc::Rc; use std::sync::{Arc, Mutex}; @@ -63,7 +64,7 @@ unsafe impl Sync for Database {} impl Database { pub fn connect(&self) -> Result { - let conn = self.inner.connect().unwrap(); + let conn = self.inner.connect()?; #[allow(clippy::arc_with_non_send_sync)] let connection = Connection { inner: Arc::new(Mutex::new(conn)), @@ -125,8 +126,14 @@ impl Statement { pub async fn query(&mut self, params: impl IntoParams) -> Result { let params = params.into_params()?; match params { - crate::params::Params::None => {} - _ => todo!(), + params::Params::None => (), + params::Params::Positional(values) => { + for (i, value) in values.into_iter().enumerate() { + let mut stmt = self.inner.lock().unwrap(); + stmt.bind_at(NonZero::new(i + 1).unwrap(), value.into()); + } + } + params::Params::Named(_items) => todo!(), } #[allow(clippy::arc_with_non_send_sync)] let rows = Rows { @@ -136,8 +143,42 @@ impl Statement { } pub async fn execute(&mut self, params: impl IntoParams) -> Result { - let _params = params.into_params()?; - todo!(); + let params = params.into_params()?; + match params { + params::Params::None => (), + params::Params::Positional(values) => { + for (i, value) in values.into_iter().enumerate() { + let mut stmt = self.inner.lock().unwrap(); + stmt.bind_at(NonZero::new(i + 1).unwrap(), value.into()); + } + } + params::Params::Named(_items) => todo!(), + } + loop { + let mut stmt = self.inner.lock().unwrap(); + match stmt.step() { + Ok(limbo_core::StepResult::Row) => { + // unexpected row during execution, error out. + return Ok(2); + } + Ok(limbo_core::StepResult::Done) => { + return Ok(0); + } + Ok(limbo_core::StepResult::IO) => { + let _ = stmt.run_once(); + //return Ok(1); + } + Ok(limbo_core::StepResult::Busy) => { + return Ok(4); + } + Ok(limbo_core::StepResult::Interrupt) => { + return Ok(3); + } + Err(err) => { + return Err(err.into()); + } + } + } } } @@ -191,7 +232,12 @@ impl Row { let value = &self.values[index]; match value { limbo_core::OwnedValue::Integer(i) => Ok(Value::Integer(*i)), - _ => todo!(), + limbo_core::OwnedValue::Null => Ok(Value::Null), + limbo_core::OwnedValue::Float(f) => Ok(Value::Real(*f)), + limbo_core::OwnedValue::Text(text) => Ok(Value::Text(text.to_string())), + limbo_core::OwnedValue::Blob(items) => Ok(Value::Blob(items.to_vec())), + limbo_core::OwnedValue::Agg(_agg_context) => todo!(), + limbo_core::OwnedValue::Record(_record) => todo!(), } } } diff --git a/bindings/rust/src/value.rs b/bindings/rust/src/value.rs index d5e4e393b..899eeb4e3 100644 --- a/bindings/rust/src/value.rs +++ b/bindings/rust/src/value.rs @@ -110,6 +110,18 @@ impl Value { } } +impl Into for Value { + fn into(self) -> limbo_core::OwnedValue { + match self { + Value::Null => limbo_core::OwnedValue::Null, + Value::Integer(n) => limbo_core::OwnedValue::Integer(n), + Value::Real(n) => limbo_core::OwnedValue::Float(n), + Value::Text(t) => limbo_core::OwnedValue::from_text(&t), + Value::Blob(items) => limbo_core::OwnedValue::from_blob(items), + } + } +} + impl From for Value { fn from(value: i8) -> Value { Value::Integer(value as i64) diff --git a/cli/app.rs b/cli/app.rs index 49f93ab53..1a063bbfb 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -1,7 +1,7 @@ use crate::{ helper::LimboHelper, import::{ImportFile, IMPORT_HELP}, - input::{get_io, get_writer, DbLocation, Io, OutputMode, Settings, HELP_MSG}, + input::{get_io, get_writer, DbLocation, OutputMode, Settings, HELP_MSG}, opcodes_dictionary::OPCODE_DESCRIPTIONS, }; use comfy_table::{Attribute, Cell, CellAlignment, Color, ContentArrangement, Row, Table}; @@ -43,14 +43,11 @@ pub struct Opts { #[clap(short, long, help = "Print commands before execution")] pub echo: bool, #[clap( - default_value_t, - value_enum, - short, + short = 'v', long, - help = "Select I/O backend. The only other choice to 'syscall' is\n\ - \t'io-uring' when built for Linux with feature 'io_uring'\n" + help = "Select VFS. options are io_uring (if feature enabled), memory, and syscall" )] - pub io: Io, + pub vfs: Option, #[clap(long, help = "Enable experimental MVCC feature")] pub experimental_mvcc: bool, } @@ -89,6 +86,8 @@ pub enum Command { LoadExtension, /// Dump the current database as a list of SQL statements Dump, + /// List vfs modules available + ListVfs, } impl Command { @@ -102,6 +101,7 @@ impl Command { | Self::ShowInfo | Self::Tables | Self::SetOutput + | Self::ListVfs | Self::Dump => 0, Self::Open | Self::OutputMode @@ -131,6 +131,7 @@ impl Command { Self::LoadExtension => ".load", Self::Dump => ".dump", Self::Import => &IMPORT_HELP, + Self::ListVfs => ".vfslist", } } } @@ -155,6 +156,7 @@ impl FromStr for Command { ".import" => Ok(Self::Import), ".load" => Ok(Self::LoadExtension), ".dump" => Ok(Self::Dump), + ".vfslist" => Ok(Self::ListVfs), _ => Err("Unknown command".to_string()), } } @@ -207,15 +209,27 @@ impl<'a> Limbo<'a> { .database .as_ref() .map_or(":memory:".to_string(), |p| p.to_string_lossy().to_string()); - - let io = { - match db_file.as_str() { - ":memory:" => get_io(DbLocation::Memory, opts.io)?, - _path => get_io(DbLocation::Path, opts.io)?, - } + let (io, db) = if let Some(ref vfs) = opts.vfs { + Database::open_new(&db_file, vfs)? + } else { + let io = { + match db_file.as_str() { + ":memory:" => get_io( + DbLocation::Memory, + opts.vfs.as_ref().map_or("", |s| s.as_str()), + )?, + _path => get_io( + DbLocation::Path, + opts.vfs.as_ref().map_or("", |s| s.as_str()), + )?, + } + }; + ( + io.clone(), + Database::open_file(io.clone(), &db_file, opts.experimental_mvcc)?, + ) }; - let db = Database::open_file(io.clone(), &db_file, opts.experimental_mvcc)?; - let conn = db.connect().unwrap(); + let conn = db.connect()?; let h = LimboHelper::new(conn.clone(), io.clone()); rl.set_helper(Some(h)); let interrupt_count = Arc::new(AtomicUsize::new(0)); @@ -408,17 +422,21 @@ impl<'a> Limbo<'a> { } } - fn open_db(&mut self, path: &str) -> anyhow::Result<()> { + fn open_db(&mut self, path: &str, vfs_name: Option<&str>) -> anyhow::Result<()> { self.conn.close()?; - let io = { - match path { - ":memory:" => get_io(DbLocation::Memory, self.opts.io)?, - _path => get_io(DbLocation::Path, self.opts.io)?, - } + let (io, db) = if let Some(vfs_name) = vfs_name { + self.conn.open_new(path, vfs_name)? + } else { + let io = { + match path { + ":memory:" => get_io(DbLocation::Memory, &self.opts.io.to_string())?, + _path => get_io(DbLocation::Path, &self.opts.io.to_string())?, + } + }; + (io.clone(), Database::open_file(io.clone(), path, false)?) }; - self.io = Arc::clone(&io); - let db = Database::open_file(self.io.clone(), path, self.opts.experimental_mvcc)?; - self.conn = db.connect().unwrap(); + self.io = io; + self.conn = db.connect()?; self.opts.db_file = path.to_string(); Ok(()) } @@ -572,7 +590,8 @@ impl<'a> Limbo<'a> { std::process::exit(0) } Command::Open => { - if self.open_db(args[1]).is_err() { + let vfs = args.get(2).map(|s| &**s); + if self.open_db(args[1], vfs).is_err() { let _ = self.writeln("Error: Unable to open database file."); } } @@ -654,6 +673,12 @@ impl<'a> Limbo<'a> { let _ = self.write_fmt(format_args!("/****** ERROR: {} ******/", e)); } } + Command::ListVfs => { + let _ = self.writeln("Available VFS modules:"); + self.conn.list_vfs().iter().for_each(|v| { + let _ = self.writeln(v); + }); + } } } else { let _ = self.write_fmt(format_args!( diff --git a/cli/input.rs b/cli/input.rs index 459b9ac2a..627389984 100644 --- a/cli/input.rs +++ b/cli/input.rs @@ -1,6 +1,7 @@ use crate::app::Opts; use clap::ValueEnum; use std::{ + fmt::{Display, Formatter}, io::{self, Write}, sync::Arc, }; @@ -11,11 +12,26 @@ pub enum DbLocation { Path, } -#[derive(Copy, Clone, ValueEnum)] +#[allow(clippy::enum_variant_names)] +#[derive(Clone, Debug)] pub enum Io { Syscall, #[cfg(all(target_os = "linux", feature = "io_uring"))] IoUring, + External(String), + Memory, +} + +impl Display for Io { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Io::Memory => write!(f, "memory"), + Io::Syscall => write!(f, "syscall"), + #[cfg(all(target_os = "linux", feature = "io_uring"))] + Io::IoUring => write!(f, "io_uring"), + Io::External(str) => write!(f, "{}", str), + } + } } impl Default for Io { @@ -65,7 +81,6 @@ pub struct Settings { pub echo: bool, pub is_stdout: bool, pub io: Io, - pub experimental_mvcc: bool, } impl From<&Opts> for Settings { @@ -80,8 +95,14 @@ impl From<&Opts> for Settings { .database .as_ref() .map_or(":memory:".to_string(), |p| p.to_string_lossy().to_string()), - io: opts.io, - experimental_mvcc: opts.experimental_mvcc, + io: match opts.vfs.as_ref().unwrap_or(&String::new()).as_str() { + "memory" => Io::Memory, + "syscall" => Io::Syscall, + #[cfg(all(target_os = "linux", feature = "io_uring"))] + "io_uring" => Io::IoUring, + "" => Io::default(), + vfs => Io::External(vfs.to_string()), + }, } } } @@ -120,12 +141,13 @@ pub fn get_writer(output: &str) -> Box { } } -pub fn get_io(db_location: DbLocation, io_choice: Io) -> anyhow::Result> { +pub fn get_io(db_location: DbLocation, io_choice: &str) -> anyhow::Result> { Ok(match db_location { DbLocation::Memory => Arc::new(limbo_core::MemoryIO::new()), DbLocation::Path => { match io_choice { - Io::Syscall => { + "memory" => Arc::new(limbo_core::MemoryIO::new()), + "syscall" => { // We are building for Linux/macOS and syscall backend has been selected #[cfg(target_family = "unix")] { @@ -139,7 +161,8 @@ pub fn get_io(db_location: DbLocation, io_choice: Io) -> anyhow::Result Arc::new(limbo_core::UringIO::new()?), + "io_uring" => Arc::new(limbo_core::UringIO::new()?), + _ => Arc::new(limbo_core::PlatformIO::new()?), } } }) diff --git a/core/Cargo.toml b/core/Cargo.toml index 0914142a1..a33f3bba2 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -26,6 +26,7 @@ crypto = ["limbo_crypto/static"] series = ["limbo_series/static"] ipaddr = ["limbo_ipaddr/static"] completion = ["limbo_completion/static"] +testvfs = ["limbo_ext_tests/static"] [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.6.1", optional = true } @@ -68,6 +69,7 @@ limbo_crypto = { workspace = true, optional = true, features = ["static"] } limbo_series = { workspace = true, optional = true, features = ["static"] } limbo_ipaddr = { workspace = true, optional = true, features = ["static"] } limbo_completion = { workspace = true, optional = true, features = ["static"] } +limbo_ext_tests = { workspace = true, optional = true, features = ["static"] } miette = "7.4.0" strum = "0.26" parking_lot = "0.12.3" @@ -97,6 +99,7 @@ rand = "0.8.5" # Required for quickcheck rand_chacha = "0.9.0" env_logger = "0.11.6" test-log = { version = "0.2.17", features = ["trace"] } +lru = "0.13.0" [[bench]] name = "benchmark" diff --git a/core/ext/dynamic.rs b/core/ext/dynamic.rs new file mode 100644 index 000000000..c6b43d81d --- /dev/null +++ b/core/ext/dynamic.rs @@ -0,0 +1,41 @@ +use crate::{Connection, LimboError}; +use libloading::{Library, Symbol}; +use limbo_ext::{ExtensionApi, ExtensionApiRef, ExtensionEntryPoint}; +use std::sync::{Arc, Mutex, OnceLock}; + +type ExtensionStore = Vec<(Arc, ExtensionApiRef)>; +static EXTENSIONS: OnceLock>> = OnceLock::new(); +pub fn get_extension_libraries() -> Arc> { + EXTENSIONS + .get_or_init(|| Arc::new(Mutex::new(Vec::new()))) + .clone() +} + +impl Connection { + pub fn load_extension>(&self, path: P) -> crate::Result<()> { + use limbo_ext::ExtensionApiRef; + + let api = Box::new(self.build_limbo_ext()); + let lib = + unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; + let entry: Symbol = unsafe { + lib.get(b"register_extension") + .map_err(|e| LimboError::ExtensionError(e.to_string()))? + }; + let api_ptr: *const ExtensionApi = Box::into_raw(api); + let api_ref = ExtensionApiRef { api: api_ptr }; + let result_code = unsafe { entry(api_ptr) }; + if result_code.is_ok() { + let extensions = get_extension_libraries(); + extensions.lock().unwrap().push((Arc::new(lib), api_ref)); + Ok(()) + } else { + if !api_ptr.is_null() { + let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; + } + Err(LimboError::ExtensionError( + "Extension registration failed".to_string(), + )) + } + } +} diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 9ed05adc4..9fdc9f90e 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,13 +1,22 @@ -use crate::{function::ExternalFunc, Connection}; +#[cfg(not(target_family = "wasm"))] +mod dynamic; +#[cfg(all(target_os = "linux", feature = "io_uring"))] +use crate::UringIO; +use crate::IO; +use crate::{function::ExternalFunc, Connection, Database, LimboError}; use limbo_ext::{ - ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabKind, VTabModuleImpl, + ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabKind, VTabModuleImpl, VfsImpl, }; pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; use std::{ ffi::{c_char, c_void, CStr, CString}, rc::Rc, + sync::{Arc, Mutex, OnceLock}, }; type ExternAggFunc = (InitAggFunction, StepFunction, FinalizeFunction); +type Vfs = (String, Arc); + +static VFS_MODULES: OnceLock>> = OnceLock::new(); #[derive(Clone)] pub struct VTabImpl { @@ -15,6 +24,14 @@ pub struct VTabImpl { pub implementation: Rc, } +#[derive(Clone, Debug)] +pub struct VfsMod { + pub ctx: *const VfsImpl, +} + +unsafe impl Send for VfsMod {} +unsafe impl Sync for VfsMod {} + unsafe extern "C" fn register_scalar_function( ctx: *mut c_void, name: *const c_char, @@ -74,6 +91,108 @@ unsafe extern "C" fn register_module( conn.register_module_impl(&name_str, module, kind) } +#[allow(clippy::arc_with_non_send_sync)] +unsafe extern "C" fn register_vfs(name: *const c_char, vfs: *const VfsImpl) -> ResultCode { + if name.is_null() || vfs.is_null() { + return ResultCode::Error; + } + let c_str = unsafe { CString::from_raw(name as *mut i8) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + add_vfs_module(name_str, Arc::new(VfsMod { ctx: vfs })); + ResultCode::OK +} + +/// Get pointers to all the vfs extensions that need to be built in at compile time. +/// any other types that are defined in the same extension will not be registered +/// until the database file is opened and `register_builtins` is called. +#[cfg(feature = "fs")] +#[allow(clippy::arc_with_non_send_sync)] +pub fn add_builtin_vfs_extensions( + api: Option, +) -> crate::Result)>> { + let mut vfslist: Vec<*const VfsImpl> = Vec::new(); + let mut api = match api { + None => ExtensionApi { + ctx: std::ptr::null_mut(), + register_scalar_function, + register_aggregate_function, + register_vfs, + register_module, + builtin_vfs: vfslist.as_mut_ptr(), + builtin_vfs_count: 0, + }, + Some(mut api) => { + api.builtin_vfs = vfslist.as_mut_ptr(); + api + } + }; + register_static_vfs_modules(&mut api); + let mut vfslist = Vec::with_capacity(api.builtin_vfs_count as usize); + let slice = + unsafe { std::slice::from_raw_parts_mut(api.builtin_vfs, api.builtin_vfs_count as usize) }; + for vfs in slice { + if vfs.is_null() { + continue; + } + let vfsimpl = unsafe { &**vfs }; + let name = unsafe { + CString::from_raw(vfsimpl.name as *mut i8) + .to_str() + .map_err(|_| { + LimboError::ExtensionError("unable to register vfs extension".to_string()) + })? + .to_string() + }; + vfslist.push(( + name, + Arc::new(VfsMod { + ctx: vfsimpl as *const _, + }), + )); + } + Ok(vfslist) +} + +fn register_static_vfs_modules(_api: &mut ExtensionApi) { + #[cfg(feature = "testvfs")] + unsafe { + limbo_ext_tests::register_extension_static(_api); + } +} + +impl Database { + #[cfg(feature = "fs")] + #[allow(clippy::arc_with_non_send_sync, dead_code)] + pub fn open_with_vfs( + &self, + path: &str, + vfs: &str, + ) -> crate::Result<(Arc, Arc)> { + use crate::{MemoryIO, PlatformIO}; + + let io: Arc = match vfs { + "memory" => Arc::new(MemoryIO::new()), + "syscall" => Arc::new(PlatformIO::new()?), + #[cfg(all(target_os = "linux", feature = "io_uring"))] + "io_uring" => Arc::new(UringIO::new()?), + other => match get_vfs_modules().iter().find(|v| v.0 == vfs) { + Some((_, vfs)) => vfs.clone(), + None => { + return Err(LimboError::InvalidArgument(format!( + "no such VFS: {}", + other + ))); + } + }, + }; + let db = Self::open_file(io.clone(), path, false)?; + Ok((io, db)) + } +} + impl Connection { fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode { self.syms.borrow_mut().functions.insert( @@ -120,44 +239,82 @@ impl Connection { register_scalar_function, register_aggregate_function, register_module, + register_vfs, + builtin_vfs: std::ptr::null_mut(), + builtin_vfs_count: 0, } } pub fn register_builtins(&self) -> Result<(), String> { #[allow(unused_variables)] - let ext_api = self.build_limbo_ext(); + let mut ext_api = self.build_limbo_ext(); #[cfg(feature = "uuid")] - if unsafe { !limbo_uuid::register_extension_static(&ext_api).is_ok() } { + if unsafe { !limbo_uuid::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register uuid extension".to_string()); } #[cfg(feature = "percentile")] - if unsafe { !limbo_percentile::register_extension_static(&ext_api).is_ok() } { + if unsafe { !limbo_percentile::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register percentile extension".to_string()); } #[cfg(feature = "regexp")] - if unsafe { !limbo_regexp::register_extension_static(&ext_api).is_ok() } { + if unsafe { !limbo_regexp::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register regexp extension".to_string()); } #[cfg(feature = "time")] - if unsafe { !limbo_time::register_extension_static(&ext_api).is_ok() } { + if unsafe { !limbo_time::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register time extension".to_string()); } #[cfg(feature = "crypto")] - if unsafe { !limbo_crypto::register_extension_static(&ext_api).is_ok() } { + if unsafe { !limbo_crypto::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register crypto extension".to_string()); } #[cfg(feature = "series")] - if unsafe { !limbo_series::register_extension_static(&ext_api).is_ok() } { + if unsafe { !limbo_series::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register series extension".to_string()); } #[cfg(feature = "ipaddr")] - if unsafe { !limbo_ipaddr::register_extension_static(&ext_api).is_ok() } { + if unsafe { !limbo_ipaddr::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register ipaddr extension".to_string()); } #[cfg(feature = "completion")] - if unsafe { !limbo_completion::register_extension_static(&ext_api).is_ok() } { + if unsafe { !limbo_completion::register_extension_static(&mut ext_api).is_ok() } { return Err("Failed to register completion extension".to_string()); } + #[cfg(feature = "fs")] + { + let vfslist = add_builtin_vfs_extensions(Some(ext_api)).map_err(|e| e.to_string())?; + for (name, vfs) in vfslist { + add_vfs_module(name, vfs); + } + } Ok(()) } } + +fn add_vfs_module(name: String, vfs: Arc) { + let mut modules = VFS_MODULES + .get_or_init(|| Mutex::new(Vec::new())) + .lock() + .unwrap(); + if !modules.iter().any(|v| v.0 == name) { + modules.push((name, vfs)); + } +} + +pub fn list_vfs_modules() -> Vec { + VFS_MODULES + .get_or_init(|| Mutex::new(Vec::new())) + .lock() + .unwrap() + .iter() + .map(|v| v.0.clone()) + .collect() +} + +fn get_vfs_modules() -> Vec { + VFS_MODULES + .get_or_init(|| Mutex::new(Vec::new())) + .lock() + .unwrap() + .clone() +} diff --git a/core/function.rs b/core/function.rs index fa10d9787..333266eea 100644 --- a/core/function.rs +++ b/core/function.rs @@ -71,6 +71,7 @@ impl Display for ExternalFunc { #[derive(Debug, Clone, PartialEq)] pub enum JsonFunc { Json, + Jsonb, JsonArray, JsonArrayLength, JsonArrowExtract, @@ -95,6 +96,7 @@ impl Display for JsonFunc { "{}", match self { Self::Json => "json".to_string(), + Self::Jsonb => "jsonb".to_string(), Self::JsonArray => "json_array".to_string(), Self::JsonExtract => "json_extract".to_string(), Self::JsonArrayLength => "json_array_length".to_string(), @@ -549,6 +551,8 @@ impl Func { #[cfg(feature = "json")] "json" => Ok(Self::Json(JsonFunc::Json)), #[cfg(feature = "json")] + "jsonb" => Ok(Self::Json(JsonFunc::Jsonb)), + #[cfg(feature = "json")] "json_array_length" => Ok(Self::Json(JsonFunc::JsonArrayLength)), #[cfg(feature = "json")] "json_array" => Ok(Self::Json(JsonFunc::JsonArray)), diff --git a/core/io/mod.rs b/core/io/mod.rs index 519109565..32ac354f4 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -19,11 +19,21 @@ pub trait File: Send + Sync { fn size(&self) -> Result; } +#[derive(Copy, Clone)] pub enum OpenFlags { None, Create, } +impl OpenFlags { + pub fn to_flags(&self) -> i32 { + match self { + Self::None => 0, + Self::Create => 1, + } + } +} + pub trait IO: Send + Sync { fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result>; @@ -203,5 +213,6 @@ cfg_block! { } mod memory; +mod vfs; pub use memory::MemoryIO; mod common; diff --git a/core/io/vfs.rs b/core/io/vfs.rs new file mode 100644 index 000000000..8ddbbc732 --- /dev/null +++ b/core/io/vfs.rs @@ -0,0 +1,153 @@ +use crate::ext::VfsMod; +use crate::{LimboError, Result}; +use limbo_ext::{VfsFileImpl, VfsImpl}; +use std::cell::RefCell; +use std::ffi::{c_void, CString}; +use std::sync::Arc; + +use super::{Buffer, Completion, File, OpenFlags, IO}; + +impl IO for VfsMod { + fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result> { + let c_path = CString::new(path).map_err(|_| { + LimboError::ExtensionError("Failed to convert path to CString".to_string()) + })?; + let ctx = self.ctx as *mut c_void; + let vfs = unsafe { &*self.ctx }; + let file = unsafe { (vfs.open)(ctx, c_path.as_ptr(), flags.to_flags(), direct) }; + if file.is_null() { + return Err(LimboError::ExtensionError("File not found".to_string())); + } + Ok(Arc::new(limbo_ext::VfsFileImpl::new(file, self.ctx)?)) + } + + fn run_once(&self) -> Result<()> { + if self.ctx.is_null() { + return Err(LimboError::ExtensionError("VFS is null".to_string())); + } + let vfs = unsafe { &*self.ctx }; + let result = unsafe { (vfs.run_once)(vfs.vfs) }; + if !result.is_ok() { + return Err(LimboError::ExtensionError(result.to_string())); + } + Ok(()) + } + + fn generate_random_number(&self) -> i64 { + if self.ctx.is_null() { + return -1; + } + let vfs = unsafe { &*self.ctx }; + unsafe { (vfs.gen_random_number)() } + } + + fn get_current_time(&self) -> String { + if self.ctx.is_null() { + return "".to_string(); + } + unsafe { + let vfs = &*self.ctx; + let chars = (vfs.current_time)(); + let cstr = CString::from_raw(chars as *mut i8); + cstr.to_string_lossy().into_owned() + } + } +} + +impl File for VfsFileImpl { + fn lock_file(&self, exclusive: bool) -> Result<()> { + let vfs = unsafe { &*self.vfs }; + let result = unsafe { (vfs.lock)(self.file, exclusive) }; + if result.is_ok() { + return Err(LimboError::ExtensionError(result.to_string())); + } + Ok(()) + } + + fn unlock_file(&self) -> Result<()> { + if self.vfs.is_null() { + return Err(LimboError::ExtensionError("VFS is null".to_string())); + } + let vfs = unsafe { &*self.vfs }; + let result = unsafe { (vfs.unlock)(self.file) }; + if result.is_ok() { + return Err(LimboError::ExtensionError(result.to_string())); + } + Ok(()) + } + + fn pread(&self, pos: usize, c: Completion) -> Result<()> { + let r = match &c { + Completion::Read(ref r) => r, + _ => unreachable!(), + }; + let result = { + let mut buf = r.buf_mut(); + let count = buf.len(); + let vfs = unsafe { &*self.vfs }; + unsafe { (vfs.read)(self.file, buf.as_mut_ptr(), count, pos as i64) } + }; + if result < 0 { + Err(LimboError::ExtensionError("pread failed".to_string())) + } else { + c.complete(0); + Ok(()) + } + } + + fn pwrite(&self, pos: usize, buffer: Arc>, c: Completion) -> Result<()> { + let buf = buffer.borrow(); + let count = buf.as_slice().len(); + if self.vfs.is_null() { + return Err(LimboError::ExtensionError("VFS is null".to_string())); + } + let vfs = unsafe { &*self.vfs }; + let result = unsafe { + (vfs.write)( + self.file, + buf.as_slice().as_ptr() as *mut u8, + count, + pos as i64, + ) + }; + + if result < 0 { + Err(LimboError::ExtensionError("pwrite failed".to_string())) + } else { + c.complete(result); + Ok(()) + } + } + + fn sync(&self, c: Completion) -> Result<()> { + let vfs = unsafe { &*self.vfs }; + let result = unsafe { (vfs.sync)(self.file) }; + if result < 0 { + Err(LimboError::ExtensionError("sync failed".to_string())) + } else { + c.complete(0); + Ok(()) + } + } + + fn size(&self) -> Result { + let vfs = unsafe { &*self.vfs }; + let result = unsafe { (vfs.size)(self.file) }; + if result < 0 { + Err(LimboError::ExtensionError("size failed".to_string())) + } else { + Ok(result as u64) + } + } +} + +impl Drop for VfsMod { + fn drop(&mut self) { + if self.ctx.is_null() { + return; + } + unsafe { + let _ = Box::from_raw(self.ctx as *mut VfsImpl); + } + } +} diff --git a/core/json/jsonb.rs b/core/json/jsonb.rs new file mode 100644 index 000000000..911f293be --- /dev/null +++ b/core/json/jsonb.rs @@ -0,0 +1,1671 @@ +use crate::{bail_parse_error, LimboError, Result}; +use std::{fmt::Write, iter::Peekable, str::from_utf8}; + +const PAYLOAD_SIZE8: u8 = 12; +const PAYLOAD_SIZE16: u8 = 13; +const PAYLOAD_SIZE32: u8 = 14; +const MAX_JSON_DEPTH: usize = 1000; +const INFINITY_CHAR_COUNT: u8 = 5; + +#[derive(Debug, Clone)] +pub struct Jsonb { + data: Vec, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum ElementType { + NULL = 0, + TRUE = 1, + FALSE = 2, + INT = 3, + INT5 = 4, + FLOAT = 5, + FLOAT5 = 6, + TEXT = 7, + TEXTJ = 8, + TEXT5 = 9, + TEXTRAW = 10, + ARRAY = 11, + OBJECT = 12, + RESERVED1 = 13, + RESERVED2 = 14, + RESERVED3 = 15, +} + +impl TryFrom for ElementType { + type Error = LimboError; + + fn try_from(value: u8) -> std::result::Result { + match value { + 0 => Ok(Self::NULL), + 1 => Ok(Self::TRUE), + 2 => Ok(Self::FALSE), + 3 => Ok(Self::INT), + 4 => Ok(Self::INT5), + 5 => Ok(Self::FLOAT), + 6 => Ok(Self::FLOAT5), + 7 => Ok(Self::TEXT), + 8 => Ok(Self::TEXTJ), + 9 => Ok(Self::TEXT5), + 10 => Ok(Self::TEXTRAW), + 11 => Ok(Self::ARRAY), + 12 => Ok(Self::OBJECT), + 13 => Ok(Self::RESERVED1), + 14 => Ok(Self::RESERVED2), + 15 => Ok(Self::RESERVED3), + _ => bail_parse_error!("Failed to recognize jsonvalue type"), + } + } +} + +type PayloadSize = usize; + +#[derive(Debug, Clone)] +pub struct JsonbHeader(ElementType, PayloadSize); + +impl JsonbHeader { + fn new(element_type: ElementType, payload_size: PayloadSize) -> Self { + Self(element_type, payload_size) + } + + fn from_slice(cursor: usize, slice: &[u8]) -> Result<(Self, usize)> { + match slice.get(cursor) { + Some(header_byte) => { + // Extract first 4 bits (values 0-15) + let element_type = header_byte & 15; + if element_type > 12 { + bail_parse_error!("Invalid element type: {}", element_type); + } + // Get the last 4 bits for header_size + let header_size = header_byte >> 4; + let offset: usize; + let total_size = match header_size { + size if size <= 11 => { + offset = 1; + size as usize + } + + 12 => match slice.get(cursor + 1) { + Some(value) => { + offset = 2; + *value as usize + } + None => bail_parse_error!("Failed to read 1-byte size"), + }, + + 13 => match Self::get_size_bytes(slice, cursor + 1, 2) { + Ok(bytes) => { + offset = 3; + u16::from_be_bytes([bytes[0], bytes[1]]) as usize + } + Err(e) => return Err(e), + }, + + 14 => match Self::get_size_bytes(slice, cursor + 1, 4) { + Ok(bytes) => { + offset = 5; + u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize + } + Err(e) => return Err(e), + }, + + _ => unreachable!(), + }; + + Ok((Self(element_type.try_into()?, total_size), offset)) + } + None => bail_parse_error!("Failed to read header byte"), + } + } + + fn into_bytes(self) -> [u8; 5] { + let mut bytes = [0; 5]; + let element_type = self.0; + let payload_size = self.1; + if payload_size <= 11 { + bytes[0] = (element_type as u8) | ((payload_size as u8) << 4); + } else if payload_size <= 0xFF { + bytes[0] = (element_type as u8) | (PAYLOAD_SIZE8 << 4); + bytes[1] = payload_size as u8; + } else if payload_size <= 0xFFFF { + bytes[0] = (element_type as u8) | (PAYLOAD_SIZE16 << 4); + + let size_bytes = (payload_size as u16).to_be_bytes(); + bytes[1] = size_bytes[0]; + bytes[2] = size_bytes[1]; + } else if payload_size <= 0xFFFFFFFF { + bytes[0] = (element_type as u8) | (PAYLOAD_SIZE32 << 4); + + let size_bytes = (payload_size as u32).to_be_bytes(); + + bytes[1] = size_bytes[0]; + bytes[2] = size_bytes[1]; + bytes[3] = size_bytes[2]; + bytes[4] = size_bytes[3]; + } else { + panic!("Payload size too large for encoding"); + } + + bytes + } + + fn get_size_bytes(slice: &[u8], start: usize, count: usize) -> Result<&[u8]> { + match slice.get(start..start + count) { + Some(bytes) => Ok(bytes), + None => bail_parse_error!("Failed to read header size"), + } + } +} + +impl Jsonb { + pub fn new(capacity: usize, data: Option<&[u8]>) -> Self { + if let Some(data) = data { + return Self { + data: data.to_vec(), + }; + } + Self { + data: Vec::with_capacity(capacity), + } + } + + pub fn len(&self) -> usize { + self.data.len() + } + + fn read_header(&self, cursor: usize) -> Result<(JsonbHeader, usize)> { + let (header, offset) = JsonbHeader::from_slice(cursor, &self.data)?; + + Ok((header, offset)) + } + + pub fn is_valid(&self) -> Result<()> { + match self.read_header(0) { + Ok((header, offset)) => { + if let Some(_) = self.data.get(offset..offset + header.1) { + Ok(()) + } else { + bail_parse_error!("malformed JSON") + } + } + Err(_) => bail_parse_error!("malformed JSON"), + } + } + + pub fn to_string(&self) -> Result { + let mut result = String::with_capacity(self.data.len() * 2); + self.write_to_string(&mut result)?; + + Ok(result) + } + + fn write_to_string(&self, string: &mut String) -> Result<()> { + let cursor = 0; + let _ = self.serialize_value(string, cursor); + Ok(()) + } + + fn serialize_value(&self, string: &mut String, cursor: usize) -> Result { + let (header, skip_header) = self.read_header(cursor)?; + let cursor = cursor + skip_header; + + let current_cursor = match header { + JsonbHeader(ElementType::OBJECT, len) => self.serialize_object(string, cursor, len)?, + JsonbHeader(ElementType::ARRAY, len) => self.serialize_array(string, cursor, len)?, + JsonbHeader(ElementType::TEXT, len) + | JsonbHeader(ElementType::TEXTRAW, len) + | JsonbHeader(ElementType::TEXTJ, len) + | JsonbHeader(ElementType::TEXT5, len) => { + self.serialize_string(string, cursor, len, &header.0)? + } + JsonbHeader(ElementType::INT, len) + | JsonbHeader(ElementType::INT5, len) + | JsonbHeader(ElementType::FLOAT, len) + | JsonbHeader(ElementType::FLOAT5, len) => { + self.serialize_number(string, cursor, len, &header.0)? + } + + JsonbHeader(ElementType::TRUE, _) => self.serialize_boolean(string, cursor, true), + JsonbHeader(ElementType::FALSE, _) => self.serialize_boolean(string, cursor, false), + JsonbHeader(ElementType::NULL, _) => self.serialize_null(string, cursor), + JsonbHeader(_, _) => { + unreachable!(); + } + }; + Ok(current_cursor) + } + + fn serialize_object(&self, string: &mut String, cursor: usize, len: usize) -> Result { + let end_cursor = cursor + len; + let mut current_cursor = cursor; + string.push('{'); + while current_cursor < end_cursor { + let (key_header, key_header_offset) = self.read_header(current_cursor)?; + current_cursor += key_header_offset; + let JsonbHeader(element_type, len) = key_header; + + match element_type { + ElementType::TEXT + | ElementType::TEXTRAW + | ElementType::TEXTJ + | ElementType::TEXT5 => { + current_cursor = + self.serialize_string(string, current_cursor, len, &element_type)?; + } + _ => bail_parse_error!("malformed JSON"), + } + + string.push(':'); + current_cursor = self.serialize_value(string, current_cursor)?; + if current_cursor < end_cursor { + string.push(','); + } + } + string.push('}'); + Ok(current_cursor) + } + + fn serialize_array(&self, string: &mut String, cursor: usize, len: usize) -> Result { + let end_cursor = cursor + len; + let mut current_cursor = cursor; + + string.push('['); + + while current_cursor < end_cursor { + current_cursor = self.serialize_value(string, current_cursor)?; + if current_cursor < end_cursor { + string.push(','); + } + } + + string.push(']'); + Ok(current_cursor) + } + + fn serialize_string( + &self, + string: &mut String, + cursor: usize, + len: usize, + kind: &ElementType, + ) -> Result { + let word_slice = &self.data[cursor..cursor + len]; + string.push('"'); + match kind { + // Can be serialized as is. Do not need escaping + ElementType::TEXT => { + let word = from_utf8(word_slice).map_err(|_| { + LimboError::ParseError("Failed to serialize string!".to_string()) + })?; + string.push_str(word); + } + + // Contain standard json escapes + ElementType::TEXTJ => { + let word = from_utf8(word_slice).map_err(|_| { + LimboError::ParseError("Failed to serialize string!".to_string()) + })?; + string.push_str(word); + } + + // We have to escape some JSON5 escape sequences + ElementType::TEXT5 => { + let mut i = 0; + while i < word_slice.len() { + let ch = word_slice[i]; + + // Handle normal characters that don't need escaping + if self.is_json_ok(ch) || ch == b'\'' { + string.push(ch as char); + i += 1; + continue; + } + + // Handle special cases + match ch { + // Double quotes need escaping + b'"' => { + string.push_str("\\\""); + i += 1; + } + + // Control characters (0x00-0x1F) + ch if ch <= 0x1F => { + match ch { + // \b + 0x08 => string.push_str("\\b"), + b'\t' => string.push_str("\\t"), + b'\n' => string.push_str("\\n"), + // \f + 0x0C => string.push_str("\\f"), + b'\r' => string.push_str("\\r"), + _ => { + // Format as \u00XX + let hex = format!("\\u{:04x}", ch); + string.push_str(&hex); + } + } + i += 1; + } + + // Handle escape sequences + b'\\' if i + 1 < word_slice.len() => { + let next_ch = word_slice[i + 1]; + match next_ch { + // Single quote + b'\'' => { + string.push('\''); + i += 2; + } + + // Vertical tab + b'v' => { + string.push_str("\\u0009"); + i += 2; + } + + // Hex escapes like \x27 + b'x' if i + 3 < word_slice.len() => { + string.push_str("\\u00"); + string.push(word_slice[i + 2] as char); + string.push(word_slice[i + 3] as char); + i += 4; + } + + // Null character + b'0' => { + string.push_str("\\u0000"); + i += 2; + } + + // CR line continuation + b'\r' => { + if i + 2 < word_slice.len() && word_slice[i + 2] == b'\n' { + i += 3; // Skip CRLF + } else { + i += 2; // Skip CR + } + } + + // LF line continuation + b'\n' => { + i += 2; + } + + // Unicode line separators (U+2028 and U+2029) + 0xe2 if i + 3 < word_slice.len() + && word_slice[i + 2] == 0x80 + && (word_slice[i + 3] == 0xa8 || word_slice[i + 3] == 0xa9) => + { + i += 4; + } + + // All other escapes pass through + _ => { + string.push('\\'); + string.push(next_ch as char); + i += 2; + } + } + } + + // Default case - just push the character + _ => { + string.push(ch as char); + i += 1; + } + } + } + } + + ElementType::TEXTRAW => { + let word = from_utf8(word_slice).map_err(|_| { + LimboError::ParseError("Failed to serialize string!".to_string()) + })?; + + for ch in word.chars() { + match ch { + '"' => string.push_str("\\\""), + '\\' => string.push_str("\\\\"), + '\x08' => string.push_str("\\b"), + '\x0C' => string.push_str("\\f"), + '\n' => string.push_str("\\n"), + '\r' => string.push_str("\\r"), + '\t' => string.push_str("\\t"), + c if c <= '\u{001F}' => { + string.push_str(&format!("\\u{:04x}", c as u32)); + } + _ => string.push(ch), + } + } + } + + _ => { + unreachable!() + } + } + string.push('"'); + Ok(cursor + len) + } + + fn is_json_ok(&self, ch: u8) -> bool { + (0x20..=0x7E).contains(&ch) && ch != b'"' && ch != b'\\' + } + + fn serialize_number( + &self, + string: &mut String, + cursor: usize, + len: usize, + kind: &ElementType, + ) -> Result { + let current_cursor = cursor + len; + let num_slice = from_utf8(&self.data[cursor..current_cursor]) + .map_err(|_| LimboError::ParseError("Failed to parse integer".to_string()))?; + + match kind { + ElementType::INT | ElementType::FLOAT => { + string.push_str(num_slice); + } + ElementType::INT5 => { + self.serialize_int5(string, num_slice)?; + } + ElementType::FLOAT5 => { + self.serialize_float5(string, num_slice)?; + } + _ => unreachable!(), + } + Ok(current_cursor) + } + + fn serialize_int5(&self, string: &mut String, hex_str: &str) -> Result<()> { + // Check if number is hex + if hex_str.len() > 2 + && (hex_str[..2].eq_ignore_ascii_case("0x") + || (hex_str.starts_with("-") || hex_str.starts_with("+")) + && hex_str[1..3].eq_ignore_ascii_case("0x")) + { + let (sign, hex_part) = if hex_str.starts_with("-0x") || hex_str.starts_with("-0X") { + ("-", &hex_str[3..]) + } else if hex_str.starts_with("+0x") || hex_str.starts_with("+0X") { + ("", &hex_str[3..]) + } else { + ("", &hex_str[2..]) + }; + + // Add sign + string.push_str(sign); + + let mut value = 0u64; + + for ch in hex_part.chars() { + if !ch.is_ascii_hexdigit() { + bail_parse_error!("Failed to parse hex digit: {}", hex_part); + } + + if (value >> 60) != 0 { + string.push_str("9.0e999"); + return Ok(()); + } + + value = value * 16 + ch.to_digit(16).unwrap_or(0) as u64; + } + write!(string, "{}", value) + .map_err(|_| LimboError::ParseError("Error writing string to json!".to_string()))?; + } else { + string.push_str(hex_str); + } + + Ok(()) + } + + fn serialize_float5(&self, string: &mut String, float_str: &str) -> Result<()> { + if float_str.len() < 2 { + bail_parse_error!("Integer is less then 2 chars: {}", float_str); + } + match float_str { + "9e999" | "-9e999" => { + string.push_str(float_str); + } + val if val.starts_with("-.") => { + string.push_str("-0."); + string.push_str(&val[2..]); + } + val if val.starts_with("+.") => { + string.push_str("0."); + string.push_str(&val[2..]); + } + val if val.starts_with(".") => { + string.push_str("0."); + string.push_str(&val[1..]); + } + val if val + .chars() + .next() + .map_or(false, |c| c.is_ascii_alphanumeric() || c == '+' || c == '-') => + { + string.push_str(val); + string.push('0'); + } + _ => bail_parse_error!("Unable to serialize float5: {}", float_str), + } + + Ok(()) + } + + fn serialize_boolean(&self, string: &mut String, cursor: usize, val: bool) -> usize { + if val { + string.push_str("true"); + } else { + string.push_str("false"); + } + + cursor + } + + fn serialize_null(&self, string: &mut String, cursor: usize) -> usize { + string.push_str("null"); + cursor + } + + fn deserialize_value<'a, I>(&mut self, input: &mut Peekable, depth: usize) -> Result + where + I: Iterator, + { + if depth > MAX_JSON_DEPTH { + bail_parse_error!("Too deep") + }; + let current_depth = depth + 1; + skip_whitespace(input); + match input.peek() { + Some(b'{') => { + input.next(); // consume '{' + self.deserialize_obj(input, current_depth) + } + Some(b'[') => { + input.next(); // consume '[' + self.deserialize_array(input, current_depth) + } + Some(b't') => self.deserialize_true(input), + Some(b'f') => self.deserialize_false(input), + Some(b'n') => self.deserialize_null_or_nan(input), + Some(b'"') => self.deserialize_string(input), + Some(b'\'') => self.deserialize_string(input), + Some(&&c) + if c.is_ascii_digit() + || c == b'-' + || c == b'+' + || c == b'.' + || c.to_ascii_lowercase() == b'i' => + { + self.deserialize_number(input) + } + Some(ch) => bail_parse_error!("Unexpected character: {}", ch), + None => Ok(0), + } + } + + pub fn deserialize_obj<'a, I>(&mut self, input: &mut Peekable, depth: usize) -> Result + where + I: Iterator, + { + if depth > MAX_JSON_DEPTH { + bail_parse_error!("Too deep!") + } + let header_pos = self.len(); + self.write_element_header(header_pos, ElementType::OBJECT, 0)?; + let obj_start = self.len(); + let mut first = true; + let current_depth = depth + 1; + loop { + skip_whitespace(input); + + match input.peek() { + Some(&&b'}') => { + input.next(); // consume '}' + if first { + return Ok(1); // empty header + } else { + let obj_size = self.len() - obj_start; + self.write_element_header(header_pos, ElementType::OBJECT, obj_size)?; + return Ok(obj_size + 2); + } + } + Some(&&b',') if !first => { + input.next(); // consume ',' + skip_whitespace(input); + } + Some(_) => { + // Parse key (must be string) + self.deserialize_string(input)?; + + skip_whitespace(input); + + // Expect and consume ':' + if input.next() != Some(&b':') { + bail_parse_error!("Expected ':' after object key"); + } + + skip_whitespace(input); + + // Parse value - can be any JSON value including another object + self.deserialize_value(input, current_depth)?; + + first = false; + } + None => { + bail_parse_error!("Unexpected end of input!") + } + } + } + } + + pub fn deserialize_array<'a, I>( + &mut self, + input: &mut Peekable, + depth: usize, + ) -> Result + where + I: Iterator, + { + if depth > MAX_JSON_DEPTH { + bail_parse_error!("Too deep"); + } + let header_pos = self.len(); + self.write_element_header(header_pos, ElementType::ARRAY, 0)?; + let arr_start = self.len(); + let mut first = true; + let current_depth = depth + 1; + loop { + skip_whitespace(input); + + match input.peek() { + Some(&&b']') => { + input.next(); + if first { + return Ok(1); + } else { + let arr_len = self.len() - arr_start; + let header_size = + self.write_element_header(header_pos, ElementType::ARRAY, arr_len)?; + return Ok(arr_len + header_size); + } + } + Some(&&b',') if !first => { + input.next(); // consume ',' + skip_whitespace(input); + } + Some(_) => { + skip_whitespace(input); + self.deserialize_value(input, current_depth)?; + + first = false; + } + None => { + bail_parse_error!("Unexpected end of input!") + } + } + } + } + + fn deserialize_string<'a, I>(&mut self, input: &mut Peekable) -> Result + where + I: Iterator, + { + let string_start = self.len(); + let quote = input.next().unwrap(); // " + let quoted = quote == &b'"' || quote == &b'\''; + let mut len = 0; + self.write_element_header(string_start, ElementType::TEXT, 0)?; + let payload_start = self.len(); + + if input.peek().is_none() { + bail_parse_error!("Unexpected end of input in string handling"); + }; + + let mut element_type = ElementType::TEXT; + // This needed to support 1 char unquoted JSON5 keys + if !quoted { + self.data.push(*quote); + len += 1; + if let Some(&&c) = input.peek() { + if c == b':' { + self.write_element_header(string_start, element_type, len)?; + + return Ok(self.len() - payload_start); + } + } + }; + + while let Some(c) = input.next() { + if c == quote && quoted { + break; + } else if c == &b'\\' { + // Handle escapes + if let Some(&esc) = input.next() { + match esc { + b'b' => { + self.data.push(b'\\'); + self.data.push(b'b'); + len += 2; + element_type = ElementType::TEXTJ; + } + b'f' => { + self.data.push(b'\\'); + self.data.push(b'f'); + len += 2; + element_type = ElementType::TEXTJ; + } + b'n' => { + self.data.push(b'\\'); + self.data.push(b'n'); + len += 2; + element_type = ElementType::TEXTJ; + } + b'r' => { + self.data.push(b'\\'); + self.data.push(b'r'); + len += 2; + element_type = ElementType::TEXTJ; + } + b't' => { + self.data.push(b'\\'); + self.data.push(b't'); + len += 2; + element_type = ElementType::TEXTJ; + } + b'\\' | b'"' | b'/' => { + self.data.push(b'\\'); + self.data.push(esc); + len += 2; + element_type = ElementType::TEXTJ; + } + b'u' => { + // Unicode escape + element_type = ElementType::TEXTJ; + self.data.push(b'\\'); + self.data.push(b'u'); + len += 2; + for _ in 0..4 { + if let Some(&h) = input.next() { + if is_hex_digit(h) { + self.data.push(h); + len += 1; + } else { + bail_parse_error!("Incomplete Unicode escape"); + } + } else { + bail_parse_error!("Incomplete Unicode escape"); + } + } + } + // JSON5 extensions + b'\n' => { + element_type = ElementType::TEXT5; + self.data.push(b'\\'); + self.data.push(b'\n'); + len += 2; + } + b'\'' => { + element_type = ElementType::TEXT5; + self.data.push(b'\\'); + self.data.push(b'\''); + len += 2; + } + b'0' => { + element_type = ElementType::TEXT5; + self.data.push(b'\\'); + self.data.push(b'0'); + len += 2; + } + b'v' => { + element_type = ElementType::TEXT5; + self.data.push(b'\\'); + self.data.push(b'v'); + len += 2; + } + b'x' => { + element_type = ElementType::TEXT5; + self.data.push(b'\\'); + self.data.push(b'x'); + len += 2; + for _ in 0..2 { + if let Some(&h) = input.next() { + if is_hex_digit(h) { + self.data.push(h); + len += 1; + } else { + bail_parse_error!("Invalid hex escape sequence"); + } + } else { + bail_parse_error!("Incomplete hex escape sequence"); + } + } + } + _ => { + bail_parse_error!("Invalid escape sequence") + } + } + } else { + bail_parse_error!("Unexpected end of input in escape sequence"); + } + } else if c <= &0x1F { + element_type = ElementType::TEXT5; + self.data.push(*c); + len += 1; + } else { + self.data.push(*c); + len += 1; + } + if let Some(&&c) = input.peek() { + if (c == b':' || c.is_ascii_whitespace()) && !quoted { + break; + } + } + } + + // Write header and payload + self.write_element_header(string_start, element_type, len)?; + + Ok(self.len() - payload_start) + } + + pub fn deserialize_number<'a, I>(&mut self, input: &mut Peekable) -> Result + where + I: Iterator, + { + let num_start = self.len(); + let mut len = 0; + let mut is_float = false; + let mut is_json5 = false; + + // Dummy header + self.write_element_header(num_start, ElementType::INT, 0)?; + + // Handle sign + if input.peek() == Some(&&b'-') || input.peek() == Some(&&b'+') { + if input.peek() == Some(&&b'+') { + is_json5 = true; + input.next(); + } else { + self.data.push(*input.next().unwrap()); + len += 1; + } + } + + // Handle json5 float number + if input.peek() == Some(&&b'.') { + is_json5 = true; + }; + + // Check for hex (JSON5) + if input.peek() == Some(&&b'0') { + self.data.push(*input.next().unwrap()); + len += 1; + let next_ch = input.peek(); + if let Some(&&ch) = next_ch { + if ch == b'x' || ch == b'X' { + self.data.push(*input.next().unwrap()); + len += 1; + while let Some(&&byte) = input.peek() { + if is_hex_digit(byte) { + self.data.push(*input.next().unwrap()); + len += 1; + } else { + break; + } + } + + self.write_element_header(num_start, ElementType::INT5, len)?; + + return Ok(self.len() - num_start); + } else if ch.is_ascii_alphanumeric() { + bail_parse_error!("Leading zero is not allowed") + } + } + } + + // Check for Infinity + if input.peek().map(|x| x.to_ascii_lowercase()) == Some(b'i') { + for expected in b"infinity" { + if input.next().map(|x| x.to_ascii_lowercase()) != Some(*expected) { + bail_parse_error!("Failed to parse number"); + } + } + self.write_element_header( + num_start, + ElementType::FLOAT5, + len + INFINITY_CHAR_COUNT as usize, + )?; + + self.data.extend_from_slice(b"9e999"); + + return Ok(self.len() - num_start); + }; + + // Regular number parsing + while let Some(&&ch) = input.peek() { + match ch { + b'0'..=b'9' => { + self.data.push(*input.next().unwrap()); + len += 1; + } + b'.' => { + is_float = true; + self.data.push(*input.next().unwrap()); + let next_ch = input.peek(); + match next_ch { + Some(ch) => { + if !ch.is_ascii_alphanumeric() { + is_json5 = true; + } + } + None => { + is_json5 = true; + } + }; + len += 1; + } + b'e' | b'E' => { + is_float = true; + self.data.push(*input.next().unwrap()); + len += 1; + if input.peek() == Some(&&b'+') || input.peek() == Some(&&b'-') { + self.data.push(*input.next().unwrap()); + len += 1; + } + } + _ => break, + } + } + + // Write appropriate header and payload + let element_type = if is_float { + if is_json5 { + ElementType::FLOAT5 + } else { + ElementType::FLOAT + } + } else { + if is_json5 { + ElementType::INT5 + } else { + ElementType::INT + } + }; + + self.write_element_header(num_start, element_type, len)?; + + Ok(self.len() - num_start) + } + + pub fn deserialize_null_or_nan<'a, I>(&mut self, input: &mut Peekable) -> Result + where + I: Iterator, + { + let start = self.len(); + let nul = b"null"; + let nan = b"nan"; + let mut nan_score = 0; + let mut nul_score = 0; + for i in 0..4 { + if nan_score == 3 { + self.data.push(ElementType::NULL as u8); + return Ok(self.len() - start); + }; + let nul_ch = nul.get(i); + let nan_ch = nan.get(i); + let ch = input.next(); + if nan_ch != ch && nul_ch != ch { + bail_parse_error!("expected null or nan"); + } + if nan_ch == ch { + nan_score += 1; + } + if nul_ch == ch { + nul_score += 1; + } + } + if nul_score == 4 { + self.data.push(ElementType::NULL as u8); + Ok(self.len() - start) + } else { + bail_parse_error!("expected null or nan"); + } + } + + pub fn deserialize_true<'a, I>(&mut self, input: &mut Peekable) -> Result + where + I: Iterator, + { + let start = self.len(); + for expected in b"true" { + if input.next() != Some(expected) { + bail_parse_error!("Expected 'true'"); + } + } + self.data.push(ElementType::TRUE as u8); + Ok(self.len() - start) + } + + fn deserialize_false<'a, I>(&mut self, input: &mut Peekable) -> Result + where + I: Iterator, + { + let start = self.len(); + for expected in b"false" { + if input.next() != Some(expected) { + bail_parse_error!("Expected 'false'"); + } + } + self.data.push(ElementType::FALSE as u8); + Ok(self.len() - start) + } + + fn write_element_header( + &mut self, + cursor: usize, + element_type: ElementType, + payload_size: usize, + ) -> Result { + let header = JsonbHeader::new(element_type, payload_size).into_bytes(); + if cursor == self.len() { + for byte in header { + if byte != 0 { + self.data.push(byte); + } + } + } else { + self.data[cursor] = header[0]; + self.data.splice( + cursor + 1..cursor + 1, + header[1..].iter().filter(|&&x| x != 0).cloned(), + ); + } + Ok(header.iter().filter(|&&x| x != 0).count()) + } + + fn from_str(input: &str) -> Result { + let mut result = Self::new(input.len(), None); + let mut input_iter = input.as_bytes().iter().peekable(); + while input_iter.peek().is_some() { + result.deserialize_value(&mut input_iter, 0)?; + } + + Ok(result) + } + + pub fn data(self) -> Vec { + self.data + } +} + +impl std::str::FromStr for Jsonb { + type Err = LimboError; + + fn from_str(s: &str) -> std::result::Result { + Self::from_str(s) + } +} + +pub fn skip_whitespace<'a, I>(input: &mut Peekable) +where + I: Iterator, +{ + while let Some(&ch) = input.peek() { + match ch { + b' ' | b'\t' | b'\n' | b'\r' => { + input.next(); + } + b'/' => { + // Handle JSON5 comments + input.next(); + if let Some(&&next_ch) = input.peek() { + if next_ch == b'/' { + // Line comment - skip until newline + input.next(); + while let Some(&c) = input.next() { + if c == b'\n' { + break; + } + } + } else if next_ch == b'*' { + // Block comment - skip until "*/" + input.next(); + let mut prev = b'\0'; + while let Some(&c) = input.next() { + if prev == b'*' && c == b'/' { + break; + } + prev = c; + } + } else { + // Not a comment, put the '/' back + break; + } + } else { + break; + } + } + _ => break, + } + } +} + +fn is_hex_digit(b: u8) -> bool { + matches!(b, b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F') +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_null_serialization() { + // Create JSONB with null value + let mut jsonb = Jsonb::new(10, None); + jsonb.data.push(ElementType::NULL as u8); + + // Test serialization + let json_str = jsonb.to_string().unwrap(); + assert_eq!(json_str, "null"); + + // Test round-trip + let reparsed = Jsonb::from_str("null").unwrap(); + assert_eq!(reparsed.data[0] as u8, ElementType::NULL as u8); + } + + #[test] + fn test_boolean_serialization() { + // True + let mut jsonb_true = Jsonb::new(10, None); + jsonb_true.data.push(ElementType::TRUE as u8); + assert_eq!(jsonb_true.to_string().unwrap(), "true"); + + // False + let mut jsonb_false = Jsonb::new(10, None); + jsonb_false.data.push(ElementType::FALSE as u8); + assert_eq!(jsonb_false.to_string().unwrap(), "false"); + + // Round-trip + let true_parsed = Jsonb::from_str("true").unwrap(); + assert_eq!(true_parsed.data[0] as u8, ElementType::TRUE as u8); + + let false_parsed = Jsonb::from_str("false").unwrap(); + assert_eq!(false_parsed.data[0] as u8, ElementType::FALSE as u8); + } + + #[test] + fn test_integer_serialization() { + // Standard integer + let parsed = Jsonb::from_str("42").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "42"); + + // Negative integer + let parsed = Jsonb::from_str("-123").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "-123"); + + // Zero + let parsed = Jsonb::from_str("0").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "0"); + + // Verify correct type + let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; + assert!(matches!(header.0, ElementType::INT)); + } + + #[test] + fn test_json5_integer_serialization() { + // Hexadecimal notation + let parsed = Jsonb::from_str("0x1A").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "26"); // Should convert to decimal + + // Positive sign (JSON5) + let parsed = Jsonb::from_str("+42").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "42"); + + // Negative hexadecimal + let parsed = Jsonb::from_str("-0xFF").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "-255"); + + // Verify correct type + let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; + assert!(matches!(header.0, ElementType::INT5)); + } + + #[test] + fn test_float_serialization() { + // Standard float + let parsed = Jsonb::from_str("3.14159").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "3.14159"); + + // Negative float + let parsed = Jsonb::from_str("-2.718").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "-2.718"); + + // Scientific notation + let parsed = Jsonb::from_str("6.022e23").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "6.022e23"); + + // Verify correct type + let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; + assert!(matches!(header.0, ElementType::FLOAT)); + } + + #[test] + fn test_json5_float_serialization() { + // Leading decimal point + let parsed = Jsonb::from_str(".123").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "0.123"); + + // Trailing decimal point + let parsed = Jsonb::from_str("42.").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "42.0"); + + // Plus sign in exponent + let parsed = Jsonb::from_str("1.5e+10").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "1.5e+10"); + + // Infinity + let parsed = Jsonb::from_str("Infinity").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "9e999"); + + // Negative Infinity + let parsed = Jsonb::from_str("-Infinity").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "-9e999"); + + // Verify correct type + let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; + assert!(matches!(header.0, ElementType::FLOAT5)); + } + + #[test] + fn test_string_serialization() { + // Simple string + let parsed = Jsonb::from_str(r#""hello world""#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#""hello world""#); + + // String with escaped characters + let parsed = Jsonb::from_str(r#""hello\nworld""#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#""hello\nworld""#); + + // Unicode escape + let parsed = Jsonb::from_str(r#""hello\u0020world""#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#""hello\u0020world""#); + + // Verify correct type + let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; + assert!(matches!(header.0, ElementType::TEXTJ)); + } + + #[test] + fn test_json5_string_serialization() { + // Single quotes + let parsed = Jsonb::from_str("'hello world'").unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#""hello world""#); + + // Hex escape + let parsed = Jsonb::from_str(r#"'\x41\x42\x43'"#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#""\u0041\u0042\u0043""#); + + // Multiline string with line continuation + let parsed = Jsonb::from_str( + r#""hello \ +world""#, + ) + .unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#""hello world""#); + + // Escaped single quote + let parsed = Jsonb::from_str(r#"'Don\'t worry'"#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#""Don't worry""#); + + // Verify correct type + let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; + assert!(matches!(header.0, ElementType::TEXT5)); + } + + #[test] + fn test_array_serialization() { + // Empty array + let parsed = Jsonb::from_str("[]").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "[]"); + + // Simple array + let parsed = Jsonb::from_str("[1,2,3]").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + + // Nested array + let parsed = Jsonb::from_str("[[1,2],[3,4]]").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "[[1,2],[3,4]]"); + + // Mixed types array + let parsed = Jsonb::from_str(r#"[1,"text",true,null,{"key":"value"}]"#).unwrap(); + assert_eq!( + parsed.to_string().unwrap(), + r#"[1,"text",true,null,{"key":"value"}]"# + ); + + // Verify correct type + let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; + assert!(matches!(header.0, ElementType::ARRAY)); + } + + #[test] + fn test_json5_array_serialization() { + // Trailing comma + let parsed = Jsonb::from_str("[1,2,3,]").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + + // Comments in array + let parsed = Jsonb::from_str("[1,/* comment */2,3]").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + + // Line comment in array + let parsed = Jsonb::from_str("[1,// line comment\n2,3]").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + } + + #[test] + fn test_object_serialization() { + // Empty object + let parsed = Jsonb::from_str("{}").unwrap(); + assert_eq!(parsed.to_string().unwrap(), "{}"); + + // Simple object + let parsed = Jsonb::from_str(r#"{"key":"value"}"#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + + // Multiple properties + let parsed = Jsonb::from_str(r#"{"a":1,"b":2,"c":3}"#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2,"c":3}"#); + + // Nested object + let parsed = Jsonb::from_str(r#"{"outer":{"inner":"value"}}"#).unwrap(); + assert_eq!( + parsed.to_string().unwrap(), + r#"{"outer":{"inner":"value"}}"# + ); + + // Mixed values + let parsed = + Jsonb::from_str(r#"{"str":"text","num":42,"bool":true,"null":null,"arr":[1,2]}"#) + .unwrap(); + assert_eq!( + parsed.to_string().unwrap(), + r#"{"str":"text","num":42,"bool":true,"null":null,"arr":[1,2]}"# + ); + + // Verify correct type + let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; + assert!(matches!(header.0, ElementType::OBJECT)); + } + + #[test] + fn test_json5_object_serialization() { + // Unquoted keys + let parsed = Jsonb::from_str("{key:\"value\"}").unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + + // Trailing comma + let parsed = Jsonb::from_str(r#"{"a":1,"b":2,}"#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2}"#); + + // Comments in object + let parsed = Jsonb::from_str(r#"{"a":1,/*comment*/"b":2}"#).unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2}"#); + + // Single quotes for keys and values + let parsed = Jsonb::from_str("{'a':'value'}").unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#"{"a":"value"}"#); + } + + #[test] + fn test_complex_json() { + let complex_json = r#"{ + "string": "Hello, world!", + "number": 42, + "float": 3.14159, + "boolean": true, + "null": null, + "array": [1, 2, 3, "text", {"nested": "object"}], + "object": { + "key1": "value1", + "key2": [4, 5, 6], + "key3": { + "nested": true + } + } + }"#; + + let parsed = Jsonb::from_str(complex_json).unwrap(); + // Round-trip test + let reparsed = Jsonb::from_str(&parsed.to_string().unwrap()).unwrap(); + assert_eq!(parsed.to_string().unwrap(), reparsed.to_string().unwrap()); + } + + #[test] + fn test_error_handling() { + // Invalid JSON syntax + assert!(Jsonb::from_str("{").is_err()); + assert!(Jsonb::from_str("[").is_err()); + assert!(Jsonb::from_str("}").is_err()); + assert!(Jsonb::from_str("]").is_err()); + + // Unclosed string + assert!(Jsonb::from_str(r#"{"key":"value"#).is_err()); + + // Invalid number format + assert!(Jsonb::from_str("01234").is_err()); // Leading zero not allowed in JSON + + // Invalid escape sequence + assert!(Jsonb::from_str(r#""\z""#).is_err()); + + // Missing colon in object + assert!(Jsonb::from_str(r#"{"key" "value"}"#).is_err()); + + // Trailing characters + assert!(Jsonb::from_str(r#"{"key":"value"} extra"#).is_err()); + } + + #[test] + fn test_depth_limit() { + // Create a JSON string that exceeds MAX_JSON_DEPTH + let mut deep_json = String::from("["); + for _ in 0..MAX_JSON_DEPTH + 1 { + deep_json.push_str("["); + } + for _ in 0..MAX_JSON_DEPTH + 1 { + deep_json.push_str("]"); + } + deep_json.push_str("]"); + + // Should fail due to exceeding depth limit + assert!(Jsonb::from_str(&deep_json).is_err()); + } + + #[test] + fn test_header_encoding() { + // Small payload (fits in 4 bits) + let header = JsonbHeader::new(ElementType::TEXT, 5); + let bytes = header.into_bytes(); + assert_eq!(bytes[0], (5 << 4) | (ElementType::TEXT as u8)); + + // Medium payload (8-bit) + let header = JsonbHeader::new(ElementType::TEXT, 200); + let bytes = header.into_bytes(); + assert_eq!(bytes[0], (PAYLOAD_SIZE8 << 4) | (ElementType::TEXT as u8)); + assert_eq!(bytes[1], 200); + + // Large payload (16-bit) + let header = JsonbHeader::new(ElementType::TEXT, 40000); + let bytes = header.into_bytes(); + assert_eq!(bytes[0], (PAYLOAD_SIZE16 << 4) | (ElementType::TEXT as u8)); + assert_eq!(bytes[1], (40000 >> 8) as u8); + assert_eq!(bytes[2], (40000 & 0xFF) as u8); + + // Extra large payload (32-bit) + let header = JsonbHeader::new(ElementType::TEXT, 70000); + let bytes = header.into_bytes(); + assert_eq!(bytes[0], (PAYLOAD_SIZE32 << 4) | (ElementType::TEXT as u8)); + assert_eq!(bytes[1], (70000 >> 24) as u8); + assert_eq!(bytes[2], ((70000 >> 16) & 0xFF) as u8); + assert_eq!(bytes[3], ((70000 >> 8) & 0xFF) as u8); + assert_eq!(bytes[4], (70000 & 0xFF) as u8); + } + + #[test] + fn test_header_decoding() { + // Create sample data with various headers + let data = vec![ + (5 << 4) | (ElementType::TEXT as u8), + (PAYLOAD_SIZE8 << 4) | (ElementType::ARRAY as u8), + 150, + (PAYLOAD_SIZE16 << 4) | (ElementType::OBJECT as u8), + 0x98, + 0x68, + ]; + + // Parse and verify each header + let (header1, offset1) = JsonbHeader::from_slice(0, &data).unwrap(); + assert_eq!(offset1, 1); + assert_eq!(header1.0, ElementType::TEXT); + assert_eq!(header1.1, 5); + + let (header2, offset2) = JsonbHeader::from_slice(1, &data).unwrap(); + assert_eq!(offset2, 2); + assert_eq!(header2.0, ElementType::ARRAY); + assert_eq!(header2.1, 150); + + let (header3, offset3) = JsonbHeader::from_slice(3, &data).unwrap(); + assert_eq!(offset3, 3); + assert_eq!(header3.0, ElementType::OBJECT); + assert_eq!(header3.1, 0x9868); // 39000 + } + + #[test] + fn test_unicode_escapes() { + // Basic unicode escape + let parsed = Jsonb::from_str(r#""\u00A9""#).unwrap(); // Copyright symbol + assert_eq!(parsed.to_string().unwrap(), r#""\u00A9""#); + + // Non-BMP character (surrogate pair) + let parsed = Jsonb::from_str(r#""\uD83D\uDE00""#).unwrap(); // Smiley emoji + assert_eq!(parsed.to_string().unwrap(), r#""\uD83D\uDE00""#); + } + + #[test] + fn test_json5_comments() { + // Line comments + let parsed = Jsonb::from_str( + r#"{ + // This is a line comment + "key": "value" + }"#, + ) + .unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + + // Block comments + let parsed = Jsonb::from_str( + r#"{ + /* This is a + block comment */ + "key": "value" + }"#, + ) + .unwrap(); + assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + + // Comments inside array + let parsed = Jsonb::from_str( + r#"[1, // Comment + 2, /* Another comment */ 3]"#, + ) + .unwrap(); + assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + } + + #[test] + fn test_whitespace_handling() { + // Various whitespace patterns + let json_with_whitespace = r#" + { + "key1" : "value1" , + "key2": [ 1, 2, 3 ] , + "key3": { + "nested" : true + } + } + "#; + + let parsed = Jsonb::from_str(json_with_whitespace).unwrap(); + assert_eq!( + parsed.to_string().unwrap(), + r#"{"key1":"value1","key2":[1,2,3],"key3":{"nested":true}}"# + ); + } + + #[test] + fn test_binary_roundtrip() { + // Test that binary data can be round-tripped through the JSONB format + let original = r#"{"test":"value","array":[1,2,3]}"#; + let parsed = Jsonb::from_str(original).unwrap(); + let binary_data = parsed.data.clone(); + + // Create a new Jsonb from the binary data + let from_binary = Jsonb::new(0, Some(&binary_data)); + assert_eq!(from_binary.to_string().unwrap(), original); + } + + #[test] + fn test_large_json() { + // Generate a large JSON with many elements + let mut large_array = String::from("["); + for i in 0..1000 { + large_array.push_str(&format!("{}", i)); + if i < 999 { + large_array.push_str(","); + } + } + large_array.push_str("]"); + + let parsed = Jsonb::from_str(&large_array).unwrap(); + assert!(parsed.to_string().unwrap().starts_with("[0,1,2,")); + assert!(parsed.to_string().unwrap().ends_with("998,999]")); + } + + #[test] + fn test_jsonb_is_valid() { + // Valid JSONB + let jsonb = Jsonb::from_str(r#"{"test":"value"}"#).unwrap(); + assert!(jsonb.is_valid().is_ok()); + + // Invalid JSONB (manually corrupted) + let mut invalid = jsonb.data.clone(); + if !invalid.is_empty() { + invalid[0] = 0xFF; // Invalid element type + let jsonb = Jsonb::new(0, Some(&invalid)); + assert!(jsonb.is_valid().is_err()); + } + } + + #[test] + fn test_special_characters_in_strings() { + // Test handling of various special characters + let json = r#"{ + "escaped_quotes": "He said \"Hello\"", + "backslashes": "C:\\Windows\\System32", + "control_chars": "\b\f\n\r\t", + "unicode": "\u00A9 2023" + }"#; + + let parsed = Jsonb::from_str(json).unwrap(); + let result = parsed.to_string().unwrap(); + + assert!(result.contains(r#""escaped_quotes":"He said \"Hello\"""#)); + assert!(result.contains(r#""backslashes":"C:\\Windows\\System32""#)); + assert!(result.contains(r#""control_chars":"\b\f\n\r\t""#)); + assert!(result.contains(r#""unicode":"\u00A9 2023""#)); + } +} diff --git a/core/json/mod.rs b/core/json/mod.rs index f7a2e0205..6f8b571f8 100644 --- a/core/json/mod.rs +++ b/core/json/mod.rs @@ -2,20 +2,23 @@ mod de; mod error; mod json_operations; mod json_path; +mod jsonb; mod ser; pub use crate::json::de::from_str; -use crate::json::de::ordered_object; use crate::json::error::Error as JsonError; pub use crate::json::json_operations::{json_patch, json_remove}; use crate::json::json_path::{json_path, JsonPath, PathElement}; pub use crate::json::ser::to_string; use crate::types::{OwnedValue, Text, TextSubtype}; +use crate::{bail_parse_error, json::de::ordered_object}; use indexmap::IndexMap; -use jsonb::Error as JsonbError; +use jsonb::Jsonb; use ser::to_string_pretty; use serde::{Deserialize, Serialize}; use std::borrow::Cow; +use std::rc::Rc; +use std::str::FromStr; #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] #[serde(untagged)] @@ -49,13 +52,12 @@ pub fn get_json(json_value: &OwnedValue, indent: Option<&str>) -> crate::Result< Ok(OwnedValue::Text(Text::json(&json))) } OwnedValue::Blob(b) => { - // TODO: use get_json_value after we implement a single Struct - // to represent both JSON and JSONB - if let Ok(json) = jsonb::from_slice(b) { - Ok(OwnedValue::Text(Text::json(&json.to_string()))) - } else { - crate::bail_parse_error!("malformed JSON"); - } + let jsonbin = Jsonb::new(b.len(), Some(b)); + jsonbin.is_valid()?; + Ok(OwnedValue::Text(Text { + value: Rc::new(jsonbin.to_string()?.into_bytes()), + subtype: TextSubtype::Json, + })) } OwnedValue::Null => Ok(OwnedValue::Null), _ => { @@ -70,6 +72,28 @@ pub fn get_json(json_value: &OwnedValue, indent: Option<&str>) -> crate::Result< } } +pub fn jsonb(json_value: &OwnedValue) -> crate::Result { + let jsonbin = match json_value { + OwnedValue::Null | OwnedValue::Integer(_) | OwnedValue::Float(_) | OwnedValue::Text(_) => { + Jsonb::from_str(&json_value.to_string()) + } + OwnedValue::Blob(blob) => { + let blob = Jsonb::new(blob.len(), Some(&blob)); + blob.is_valid()?; + Ok(blob) + } + _ => { + unimplemented!() + } + }; + match jsonbin { + Ok(jsonbin) => Ok(OwnedValue::Blob(Rc::new(jsonbin.data()))), + Err(_) => { + bail_parse_error!("malformed JSON") + } + } +} + fn get_json_value(json_value: &OwnedValue) -> crate::Result { match json_value { OwnedValue::Text(ref t) => match from_str::(t.as_str()) { @@ -78,12 +102,8 @@ fn get_json_value(json_value: &OwnedValue) -> crate::Result { crate::bail_parse_error!("malformed JSON") } }, - OwnedValue::Blob(b) => { - if let Ok(_json) = jsonb::from_slice(b) { - todo!("jsonb to json conversion"); - } else { - crate::bail_parse_error!("malformed JSON"); - } + OwnedValue::Blob(_) => { + crate::bail_parse_error!("malformed JSON"); } OwnedValue::Null => Ok(Val::Null), OwnedValue::Float(f) => Ok(Val::Float(*f)), @@ -625,13 +645,9 @@ pub fn json_error_position(json: &OwnedValue) -> crate::Result { } } }, - OwnedValue::Blob(b) => match jsonb::from_slice(b) { - Ok(_) => Ok(OwnedValue::Integer(0)), - Err(JsonbError::Syntax(_, pos)) => Ok(OwnedValue::Integer(pos as i64)), - _ => Err(crate::error::LimboError::InternalError( - "failed to determine json error position".into(), - )), - }, + OwnedValue::Blob(_) => { + bail_parse_error!("Unsupported") + } OwnedValue::Null => Ok(OwnedValue::Null), _ => Ok(OwnedValue::Integer(0)), } @@ -667,10 +683,9 @@ pub fn is_json_valid(json_value: &OwnedValue) -> crate::Result { Ok(_) => Ok(OwnedValue::Integer(1)), Err(_) => Ok(OwnedValue::Integer(0)), }, - OwnedValue::Blob(b) => match jsonb::from_slice(b) { - Ok(_) => Ok(OwnedValue::Integer(1)), - Err(_) => Ok(OwnedValue::Integer(0)), - }, + OwnedValue::Blob(_) => { + bail_parse_error!("Unsuported!") + } OwnedValue::Null => Ok(OwnedValue::Null), _ => Ok(OwnedValue::Integer(1)), } @@ -814,11 +829,11 @@ mod tests { #[test] fn test_get_json_blob_valid_jsonb() { - let binary_json = b"\x40\0\0\x01\x10\0\0\x03\x10\0\0\x03\x61\x73\x64\x61\x64\x66".to_vec(); + let binary_json = vec![124, 55, 104, 101, 121, 39, 121, 111]; let input = OwnedValue::Blob(Rc::new(binary_json)); let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { - assert!(result_str.as_str().contains("\"asd\":\"adf\"")); + assert!(result_str.as_str().contains(r#"{"hey":"yo"}"#)); assert_eq!(result_str.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -830,6 +845,7 @@ mod tests { let binary_json: Vec = vec![0xA2, 0x62, 0x6B, 0x31, 0x62, 0x76]; // Incomplete binary JSON let input = OwnedValue::Blob(Rc::new(binary_json)); let result = get_json(&input, None); + println!("{:?}", result); match result { Ok(_) => panic!("Expected error for malformed JSON"), Err(e) => assert!(e.to_string().contains("malformed JSON")), @@ -1070,13 +1086,6 @@ mod tests { assert_eq!(result, OwnedValue::Integer(0)); } - #[test] - fn test_json_error_position_blob() { - let input = OwnedValue::Blob(Rc::new(r#"["a",55,"b",72,,]"#.as_bytes().to_owned())); - let result = json_error_position(&input).unwrap(); - assert_eq!(result, OwnedValue::Integer(16)); - } - #[test] fn test_json_object_simple() { let key = OwnedValue::build_text("key"); diff --git a/core/lib.rs b/core/lib.rs index 9632e1829..eb2036d9e 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -23,11 +23,8 @@ mod vector; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; +use ext::list_vfs_modules; use fallible_iterator::FallibleIterator; -#[cfg(not(target_family = "wasm"))] -use libloading::{Library, Symbol}; -#[cfg(not(target_family = "wasm"))] -use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; use limbo_ext::{ResultCode, VTabKind, VTabModuleImpl}; use limbo_sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser}; use parking_lot::RwLock; @@ -204,6 +201,31 @@ impl Database { } Ok(conn) } + + /// Open a new database file with a specified VFS without an existing database + /// connection and symbol table to register extensions. + #[cfg(feature = "fs")] + #[allow(clippy::arc_with_non_send_sync)] + pub fn open_new(path: &str, vfs: &str) -> Result<(Arc, Arc)> { + let vfsmods = ext::add_builtin_vfs_extensions(None)?; + let io: Arc = match vfsmods.iter().find(|v| v.0 == vfs).map(|v| v.1.clone()) { + Some(vfs) => vfs, + None => match vfs.trim() { + "memory" => Arc::new(MemoryIO::new()), + "syscall" => Arc::new(PlatformIO::new()?), + #[cfg(all(target_os = "linux", feature = "io_uring"))] + "io_uring" => Arc::new(UringIO::new()?), + other => { + return Err(LimboError::InvalidArgument(format!( + "no such VFS: {}", + other + ))); + } + }, + }; + let db = Self::open_file(io.clone(), path, false)?; + Ok((io, db)) + } } pub fn maybe_init_database_file(file: &Arc, io: &Arc) -> Result<()> { @@ -279,8 +301,7 @@ impl Connection { match cmd { Cmd::Stmt(stmt) => { let program = Rc::new(translate::translate( - &self - .schema + self.schema .try_read() .ok_or(LimboError::SchemaLocked)? .deref(), @@ -321,8 +342,7 @@ impl Connection { match cmd { Cmd::Stmt(stmt) => { let program = Rc::new(translate::translate( - &self - .schema + self.schema .try_read() .ok_or(LimboError::SchemaLocked)? .deref(), @@ -338,8 +358,7 @@ impl Connection { } Cmd::Explain(stmt) => { let program = translate::translate( - &self - .schema + self.schema .try_read() .ok_or(LimboError::SchemaLocked)? .deref(), @@ -357,8 +376,7 @@ impl Connection { match stmt { ast::Stmt::Select(select) => { let mut plan = prepare_select_plan( - &self - .schema + self.schema .try_read() .ok_or(LimboError::SchemaLocked)? .deref(), @@ -368,8 +386,7 @@ impl Connection { )?; optimize_plan( &mut plan, - &self - .schema + self.schema .try_read() .ok_or(LimboError::SchemaLocked)? .deref(), @@ -387,6 +404,8 @@ impl Connection { QueryRunner::new(self, sql) } + /// Execute will run a query from start to finish taking ownership of I/O because it will run pending I/Os if it didn't finish. + /// TODO: make this api async pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); let mut parser = Parser::new(sql.as_bytes()); @@ -396,8 +415,7 @@ impl Connection { match cmd { Cmd::Explain(stmt) => { let program = translate::translate( - &self - .schema + self.schema .try_read() .ok_or(LimboError::SchemaLocked)? .deref(), @@ -413,8 +431,7 @@ impl Connection { Cmd::ExplainQueryPlan(_stmt) => todo!(), Cmd::Stmt(stmt) => { let program = translate::translate( - &self - .schema + self.schema .try_read() .ok_or(LimboError::SchemaLocked)? .deref(), @@ -428,7 +445,17 @@ impl Connection { let mut state = vdbe::ProgramState::new(program.max_registers, program.cursor_ref.len()); - program.step(&mut state, self._db.mv_store.clone(), self.pager.clone())?; + loop { + let res = program.step( + &mut state, + self._db.mv_store.clone(), + self.pager.clone(), + )?; + if matches!(res, StepResult::Done) { + break; + } + self._db.io.run_once()?; + } } } } @@ -449,30 +476,6 @@ impl Connection { Ok(checkpoint_result) } - #[cfg(not(target_family = "wasm"))] - pub fn load_extension>(&self, path: P) -> Result<()> { - let api = Box::new(self.build_limbo_ext()); - let lib = - unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; - let entry: Symbol = unsafe { - lib.get(b"register_extension") - .map_err(|e| LimboError::ExtensionError(e.to_string()))? - }; - let api_ptr: *const ExtensionApi = Box::into_raw(api); - let result_code = unsafe { entry(api_ptr) }; - if result_code.is_ok() { - self.syms.borrow_mut().extensions.push((lib, api_ptr)); - Ok(()) - } else { - if !api_ptr.is_null() { - let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; - } - Err(LimboError::ExtensionError( - "Extension registration failed".to_string(), - )) - } - } - /// Close a connection and checkpoint. pub fn close(&self) -> Result<()> { loop { @@ -505,6 +508,28 @@ impl Connection { pub fn total_changes(&self) -> i64 { self.total_changes.get() } + + #[cfg(feature = "fs")] + pub fn open_new(&self, path: &str, vfs: &str) -> Result<(Arc, Arc)> { + Database::open_with_vfs(&self._db, path, vfs) + } + + pub fn list_vfs(&self) -> Vec { + let mut all_vfs = vec![String::from("memory")]; + #[cfg(feature = "fs")] + { + #[cfg(all(feature = "fs", target_family = "unix"))] + { + all_vfs.push("syscall".to_string()); + } + #[cfg(all(feature = "fs", target_os = "linux", feature = "io_uring"))] + { + all_vfs.push("io_uring".to_string()); + } + } + all_vfs.extend(list_vfs_modules()); + all_vfs + } } pub struct Statement { @@ -542,6 +567,10 @@ impl Statement { .step(&mut self.state, self.mv_store.clone(), self.pager.clone()) } + pub fn run_once(&self) -> Result<()> { + self.pager.io.run_once() + } + pub fn num_columns(&self) -> usize { self.program.result_columns.len() } @@ -707,8 +736,6 @@ impl VirtualTable { pub(crate) struct SymbolTable { pub functions: HashMap>, - #[cfg(not(target_family = "wasm"))] - extensions: Vec<(Library, *const ExtensionApi)>, pub vtabs: HashMap>, pub vtab_modules: HashMap>, } @@ -753,8 +780,6 @@ impl SymbolTable { Self { functions: HashMap::new(), vtabs: HashMap::new(), - #[cfg(not(target_family = "wasm"))] - extensions: Vec::new(), vtab_modules: HashMap::new(), } } diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 0069a8f76..bdbaf7ce9 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -284,7 +284,7 @@ impl BTreeCursor { } let cell_idx = cell_idx as usize; - debug!( + tracing::trace!( "get_prev_record current id={} cell={}", page.get().id, cell_idx @@ -359,7 +359,7 @@ impl BTreeCursor { let mem_page_rc = self.stack.top(); let cell_idx = self.stack.current_cell_index() as usize; - debug!("current id={} cell={}", mem_page_rc.get().id, cell_idx); + tracing::trace!("current id={} cell={}", mem_page_rc.get().id, cell_idx); return_if_locked!(mem_page_rc); if !mem_page_rc.is_loaded() { self.pager.load_page(mem_page_rc.clone())?; @@ -846,8 +846,7 @@ impl BTreeCursor { cell_payload.as_slice(), cell_idx, self.usable_space() as u16, - ) - .unwrap(); + )?; contents.overflow_cells.len() }; let write_info = self @@ -1658,7 +1657,7 @@ impl BTreeCursor { } pub fn rewind(&mut self) -> Result> { - if let Some(_) = &self.mv_cursor { + if self.mv_cursor.is_some() { let (rowid, record) = return_if_io!(self.get_next_record(None)); self.rowid.replace(rowid); self.record.replace(record); @@ -2314,41 +2313,49 @@ impl CellArray { fn find_free_cell(page_ref: &PageContent, usable_space: u16, amount: usize) -> Result { // NOTE: freelist is in ascending order of keys and pc // unuse_space is reserved bytes at the end of page, therefore we must substract from maxpc - let mut pc = page_ref.first_freeblock() as usize; let mut prev_pc = page_ref.offset + PAGE_HEADER_OFFSET_FIRST_FREEBLOCK; + let mut pc = page_ref.first_freeblock() as usize; + let maxpc = usable_space as usize - amount; - let buf = page_ref.as_ptr(); - - let usable_space = usable_space as usize; - let maxpc = usable_space - amount; while pc <= maxpc { - let next = u16::from_be_bytes(buf[pc..pc + 2].try_into().unwrap()); - let size = u16::from_be_bytes(buf[pc + 2..pc + 4].try_into().unwrap()); - if amount <= size as usize { - if amount == size as usize { - // delete whole thing - page_ref.write_u16(PAGE_HEADER_OFFSET_FIRST_FREEBLOCK, next); - } else { - // take only the part we are interested in by reducing the size - let new_size = size - amount as u16; - // size includes 4 bytes of freeblock - // we need to leave the free block at least - if new_size >= 4 { - buf[pc + 2..pc + 4].copy_from_slice(&new_size.to_be_bytes()); - } else { - // increase fragment size and delete entry from free list - buf[prev_pc..prev_pc + 2].copy_from_slice(&next.to_be_bytes()); - let frag = page_ref.num_frag_free_bytes() + new_size as u8; - page_ref.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, frag); - } - pc += new_size as usize; - } - return Ok(pc); + if pc + 4 > usable_space as usize { + return_corrupt!("Free block header extends beyond page"); } + + let next = page_ref.read_u16_no_offset(pc); + let size = page_ref.read_u16_no_offset(pc + 2); + + if amount <= size as usize { + let new_size = size as usize - amount; + if new_size < 4 { + // The code is checking if using a free slot that would leave behind a very small fragment (x < 4 bytes) + // would cause the total fragmentation to exceed the limit of 60 bytes + // check sqlite docs https://www.sqlite.org/fileformat.html#:~:text=A%20freeblock%20requires,not%20exceed%2060 + if page_ref.num_frag_free_bytes() > 57 { + return Ok(0); + } + // Delete the slot from freelist and update the page's fragment count. + page_ref.write_u16(prev_pc, next); + let frag = page_ref.num_frag_free_bytes() + new_size as u8; + page_ref.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, frag); + return Ok(pc); + } else if new_size + pc > maxpc { + return_corrupt!("Free block extends beyond page end"); + } else { + // Requested amount fits inside the current free slot so we reduce its size + // to account for newly allocated space. + page_ref.write_u16(pc + 2, new_size as u16); + return Ok(pc + new_size); + } + } + prev_pc = pc; pc = next as usize; - if pc <= prev_pc && pc != 0 { - return_corrupt!("Free list not in ascending order"); + if pc <= prev_pc { + if pc != 0 { + return_corrupt!("Free list not in ascending order"); + } + return Ok(0); } } if pc > maxpc + amount - 4 { @@ -2523,6 +2530,14 @@ fn free_cell_range( len: u16, usable_space: u16, ) -> Result<()> { + if len < 4 { + return_corrupt!("Minimum cell size is 4"); + } + + if offset > usable_space.saturating_sub(4) { + return_corrupt!("Start offset beyond usable space"); + } + let mut size = len; let mut end = offset + len; let mut pointer_to_pc = page.offset as u16 + 1; @@ -2584,7 +2599,6 @@ fn free_cell_range( } let frag = page.num_frag_free_bytes() - removed_fragmentation; page.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, frag); - pc }; @@ -2602,6 +2616,7 @@ fn free_cell_range( page.write_u16_no_offset(offset as usize, pc); page.write_u16_no_offset(offset as usize + 2, size); } + Ok(()) } @@ -3037,7 +3052,6 @@ mod tests { use std::sync::Arc; use std::sync::Mutex; - use rand::{thread_rng, Rng}; use tempfile::TempDir; use crate::{ @@ -3526,21 +3540,25 @@ mod tests { } #[test] + #[ignore] pub fn btree_insert_fuzz_run_random() { btree_insert_fuzz_run(128, 16, |rng| (rng.next_u32() % 4096) as usize); } #[test] + #[ignore] pub fn btree_insert_fuzz_run_small() { btree_insert_fuzz_run(1, 1024, |rng| (rng.next_u32() % 128) as usize); } #[test] + #[ignore] pub fn btree_insert_fuzz_run_big() { btree_insert_fuzz_run(64, 32, |rng| 3 * 1024 + (rng.next_u32() % 1024) as usize); } #[test] + #[ignore] pub fn btree_insert_fuzz_run_overflow() { btree_insert_fuzz_run(64, 32, |rng| (rng.next_u32() % 32 * 1024) as usize); } @@ -3924,7 +3942,9 @@ mod tests { let mut cells = Vec::new(); let usable_space = 4096; let mut i = 1000; - let seed = thread_rng().gen(); + // let seed = thread_rng().gen(); + // let seed = 15292777653676891381; + let seed = 9261043168681395159; tracing::info!("seed {}", seed); let mut rng = ChaCha8Rng::seed_from_u64(seed); while i > 0 { @@ -3978,6 +3998,76 @@ mod tests { } } + #[test] + pub fn test_fuzz_drop_defragment_insert_issue_1085() { + // This test is used to demonstrate that issue at https://github.com/tursodatabase/limbo/issues/1085 + // is FIXED. + let db = get_database(); + let conn = db.connect().unwrap(); + + let page = get_page(2); + let page = page.get_contents(); + let header_size = 8; + + let mut total_size = 0; + let mut cells = Vec::new(); + let usable_space = 4096; + let mut i = 1000; + for seed in [15292777653676891381, 9261043168681395159] { + tracing::info!("seed {}", seed); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + while i > 0 { + i -= 1; + match rng.next_u64() % 3 { + 0 => { + // allow appends with extra place to insert + let cell_idx = rng.next_u64() as usize % (page.cell_count() + 1); + let free = compute_free_space(page, usable_space); + let record = Record::new([OwnedValue::Integer(i as i64)].to_vec()); + let mut payload: Vec = Vec::new(); + fill_cell_payload( + page.page_type(), + Some(i as u64), + &mut payload, + &record, + 4096, + conn.pager.clone(), + ); + if (free as usize) < payload.len() - 2 { + // do not try to insert overflow pages because they require balancing + continue; + } + insert_into_cell(page, &payload, cell_idx, 4096).unwrap(); + assert!(page.overflow_cells.is_empty()); + total_size += payload.len() as u16 + 2; + cells.push(Cell { pos: i, payload }); + } + 1 => { + if page.cell_count() == 0 { + continue; + } + let cell_idx = rng.next_u64() as usize % page.cell_count(); + let (_, len) = page.cell_get_raw_region( + cell_idx, + payload_overflow_threshold_max(page.page_type(), 4096), + payload_overflow_threshold_min(page.page_type(), 4096), + usable_space as usize, + ); + drop_cell(page, cell_idx, usable_space).unwrap(); + total_size -= len as u16 + 2; + cells.remove(cell_idx); + } + 2 => { + defragment_page(page, usable_space); + } + _ => unreachable!(), + } + let free = compute_free_space(page, usable_space); + assert_eq!(free, 4096 - total_size - header_size); + } + } + } + #[test] pub fn test_free_space() { let db = get_database(); @@ -4002,7 +4092,7 @@ mod tests { let page = page.get_contents(); let usable_space = 4096; - let record = Record::new([OwnedValue::Integer(0 as i64)].to_vec()); + let record = Record::new([OwnedValue::Integer(0)].to_vec()); let payload = add_record(0, 0, page, record, &conn); assert_eq!(page.cell_count(), 1); @@ -4040,7 +4130,7 @@ mod tests { drop_cell(page, 0, usable_space).unwrap(); assert_eq!(page.cell_count(), 0); - let record = Record::new([OwnedValue::Integer(0 as i64)].to_vec()); + let record = Record::new([OwnedValue::Integer(0)].to_vec()); let payload = add_record(0, 0, page, record, &conn); assert_eq!(page.cell_count(), 1); diff --git a/core/storage/page_cache.rs b/core/storage/page_cache.rs index 2ead4a58b..2004c511a 100644 --- a/core/storage/page_cache.rs +++ b/core/storage/page_cache.rs @@ -62,37 +62,34 @@ impl DumbLruPageCache { pub fn insert(&mut self, key: PageCacheKey, value: PageRef) { self._delete(key.clone(), false); debug!("cache_insert(key={:?})", key); - let mut entry = Box::new(PageCacheEntry { + let entry = Box::new(PageCacheEntry { key: key.clone(), next: None, prev: None, page: value, }); - self.touch(&mut entry); + let ptr_raw = Box::into_raw(entry); + let ptr = unsafe { ptr_raw.as_mut().unwrap().as_non_null() }; + self.touch(ptr); - if self.map.borrow().len() >= self.capacity { + self.map.borrow_mut().insert(key, ptr); + if self.len() > self.capacity { self.pop_if_not_dirty(); } - let b = Box::into_raw(entry); - let as_non_null = NonNull::new(b).unwrap(); - self.map.borrow_mut().insert(key, as_non_null); } pub fn delete(&mut self, key: PageCacheKey) { + debug!("cache_delete(key={:?})", key); self._delete(key, true) } pub fn _delete(&mut self, key: PageCacheKey, clean_page: bool) { - debug!("cache_delete(key={:?}, clean={})", key, clean_page); let ptr = self.map.borrow_mut().remove(&key); if ptr.is_none() { return; } - let mut ptr = ptr.unwrap(); - { - let ptr = unsafe { ptr.as_mut() }; - self.detach(ptr, clean_page); - } + let ptr = ptr.unwrap(); + self.detach(ptr, clean_page); unsafe { std::ptr::drop_in_place(ptr.as_ptr()) }; } @@ -103,13 +100,18 @@ impl DumbLruPageCache { } pub fn get(&mut self, key: &PageCacheKey) -> Option { + self.peek(key, true) + } + + /// Get page without promoting entry + pub fn peek(&mut self, key: &PageCacheKey, touch: bool) -> Option { debug!("cache_get(key={:?})", key); - let ptr = self.get_ptr(key); - ptr?; - let ptr = unsafe { ptr.unwrap().as_mut() }; - let page = ptr.page.clone(); - //self.detach(ptr); - self.touch(ptr); + let mut ptr = self.get_ptr(key)?; + let page = unsafe { ptr.as_mut().page.clone() }; + if touch { + self.detach(ptr, false); + self.touch(ptr); + } Some(page) } @@ -118,19 +120,17 @@ impl DumbLruPageCache { todo!(); } - fn detach(&mut self, entry: &mut PageCacheEntry, clean_page: bool) { - let mut current = entry.as_non_null(); - + fn detach(&mut self, mut entry: NonNull, clean_page: bool) { if clean_page { // evict buffer - let page = &entry.page; + let page = unsafe { &entry.as_mut().page }; page.clear_loaded(); debug!("cleaning up page {}", page.get().id); let _ = page.get().contents.take(); } let (next, prev) = unsafe { - let c = current.as_mut(); + let c = entry.as_mut(); let next = c.next; let prev = c.prev; c.prev = None; @@ -140,9 +140,16 @@ impl DumbLruPageCache { // detach match (prev, next) { - (None, None) => {} - (None, Some(_)) => todo!(), - (Some(p), None) => { + (None, None) => { + self.head.replace(None); + self.tail.replace(None); + } + (None, Some(mut n)) => { + unsafe { n.as_mut().prev = None }; + self.head.borrow_mut().replace(n); + } + (Some(mut p), None) => { + unsafe { p.as_mut().next = None }; self.tail = RefCell::new(Some(p)); } (Some(mut p), Some(mut n)) => unsafe { @@ -154,19 +161,20 @@ impl DumbLruPageCache { }; } - fn touch(&mut self, entry: &mut PageCacheEntry) { - let mut current = entry.as_non_null(); - unsafe { - let c = current.as_mut(); - c.next = *self.head.borrow(); - } - + /// inserts into head, assuming we detached first + fn touch(&mut self, mut entry: NonNull) { if let Some(mut head) = *self.head.borrow_mut() { unsafe { + entry.as_mut().next.replace(head); let head = head.as_mut(); - head.prev = Some(current); + head.prev = Some(entry); } } + + if self.tail.borrow().is_none() { + self.tail.borrow_mut().replace(entry); + } + self.head.borrow_mut().replace(entry); } fn pop_if_not_dirty(&mut self) { @@ -174,12 +182,14 @@ impl DumbLruPageCache { if tail.is_none() { return; } - let tail = unsafe { tail.unwrap().as_mut() }; - if tail.page.is_dirty() { + let mut tail = tail.unwrap(); + let tail_entry = unsafe { tail.as_mut() }; + if tail_entry.page.is_dirty() { // TODO: drop from another clean entry? return; } self.detach(tail, true); + assert!(self.map.borrow_mut().remove(&tail_entry.key).is_some()); } pub fn clear(&mut self) { @@ -188,4 +198,148 @@ impl DumbLruPageCache { self.delete(key); } } + + pub fn print(&mut self) { + println!("page_cache={}", self.map.borrow().len()); + println!("page_cache={:?}", self.map.borrow()) + } + + pub fn len(&self) -> usize { + self.map.borrow().len() + } +} + +#[cfg(test)] +mod tests { + use std::{num::NonZeroUsize, sync::Arc}; + + use lru::LruCache; + use rand_chacha::{ + rand_core::{RngCore, SeedableRng}, + ChaCha8Rng, + }; + + use crate::{storage::page_cache::DumbLruPageCache, Page}; + + use super::PageCacheKey; + + #[test] + fn test_page_cache_evict() { + let mut cache = DumbLruPageCache::new(1); + let key1 = insert_page(&mut cache, 1); + let key2 = insert_page(&mut cache, 2); + assert_eq!(cache.get(&key2).unwrap().get().id, 2); + assert!(cache.get(&key1).is_none()); + } + + #[test] + fn test_page_cache_fuzz() { + let seed = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + tracing::info!("super seed: {}", seed); + let max_pages = 10; + let mut cache = DumbLruPageCache::new(10); + let mut lru = LruCache::new(NonZeroUsize::new(10).unwrap()); + + for _ in 0..10000 { + match rng.next_u64() % 3 { + 0 => { + // add + let id_page = rng.next_u64() % max_pages; + let id_frame = rng.next_u64() % max_pages; + let key = PageCacheKey::new(id_page as usize, Some(id_frame)); + #[allow(clippy::arc_with_non_send_sync)] + let page = Arc::new(Page::new(id_page as usize)); + // println!("inserting page {:?}", key); + cache.insert(key.clone(), page.clone()); + lru.push(key, page); + assert!(cache.len() <= 10); + } + 1 => { + // remove + let random = rng.next_u64() % 2 == 0; + let key = if random || lru.is_empty() { + let id_page = rng.next_u64() % max_pages; + let id_frame = rng.next_u64() % max_pages; + let key = PageCacheKey::new(id_page as usize, Some(id_frame)); + key + } else { + let i = rng.next_u64() as usize % lru.len(); + let key = lru.iter().skip(i).next().unwrap().0.clone(); + key + }; + // println!("removing page {:?}", key); + lru.pop(&key); + cache.delete(key); + } + 2 => { + // test contents + for (key, page) in &lru { + // println!("getting page {:?}", key); + cache.peek(&key, false).unwrap(); + assert_eq!(page.get().id, key.pgno); + } + } + _ => unreachable!(), + } + } + } + + #[test] + fn test_page_cache_insert_and_get() { + let mut cache = DumbLruPageCache::new(2); + let key1 = insert_page(&mut cache, 1); + let key2 = insert_page(&mut cache, 2); + assert_eq!(cache.get(&key1).unwrap().get().id, 1); + assert_eq!(cache.get(&key2).unwrap().get().id, 2); + } + + #[test] + fn test_page_cache_over_capacity() { + let mut cache = DumbLruPageCache::new(2); + let key1 = insert_page(&mut cache, 1); + let key2 = insert_page(&mut cache, 2); + let key3 = insert_page(&mut cache, 3); + assert!(cache.get(&key1).is_none()); + assert_eq!(cache.get(&key2).unwrap().get().id, 2); + assert_eq!(cache.get(&key3).unwrap().get().id, 3); + } + + #[test] + fn test_page_cache_delete() { + let mut cache = DumbLruPageCache::new(2); + let key1 = insert_page(&mut cache, 1); + cache.delete(key1.clone()); + assert!(cache.get(&key1).is_none()); + } + + #[test] + fn test_page_cache_clear() { + let mut cache = DumbLruPageCache::new(2); + let key1 = insert_page(&mut cache, 1); + let key2 = insert_page(&mut cache, 2); + cache.clear(); + assert!(cache.get(&key1).is_none()); + assert!(cache.get(&key2).is_none()); + } + + fn insert_page(cache: &mut DumbLruPageCache, id: usize) -> PageCacheKey { + let key = PageCacheKey::new(id, None); + #[allow(clippy::arc_with_non_send_sync)] + let page = Arc::new(Page::new(id)); + cache.insert(key.clone(), page.clone()); + key + } + + #[test] + fn test_page_cache_insert_sequential() { + let mut cache = DumbLruPageCache::new(2); + for i in 0..10000 { + let key = insert_page(&mut cache, i); + assert_eq!(cache.peek(&key, false).unwrap().get().id, i); + } + } } diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 54ef933de..5c43eee1a 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -122,7 +122,7 @@ impl Page { } } -#[derive(Clone)] +#[derive(Clone, Debug)] enum FlushState { Start, WaitAppendFrames, @@ -247,6 +247,7 @@ impl Pager { match checkpoint_status { CheckpointStatus::IO => Ok(checkpoint_status), CheckpointStatus::Done(_) => { + self.wal.borrow().end_write_tx()?; self.wal.borrow().end_read_tx()?; Ok(checkpoint_status) } @@ -260,11 +261,11 @@ impl Pager { /// Reads a page from the database. pub fn read_page(&self, page_idx: usize) -> Result { - trace!("read_page(page_idx = {})", page_idx); + tracing::debug!("read_page(page_idx = {})", page_idx); let mut page_cache = self.page_cache.write(); let page_key = PageCacheKey::new(page_idx, Some(self.wal.borrow().get_max_frame())); if let Some(page) = page_cache.get(&page_key) { - trace!("read_page(page_idx = {}) = cached", page_idx); + tracing::debug!("read_page(page_idx = {}) = cached", page_idx); return Ok(page.clone()); } let page = Arc::new(Page::new(page_idx)); @@ -347,6 +348,7 @@ impl Pager { let mut checkpoint_result = CheckpointResult::new(); loop { let state = self.flush_info.borrow().state.clone(); + trace!("cacheflush {:?}", state); match state { FlushState::Start => { let db_size = self.db_header.lock().unwrap().database_size; @@ -362,6 +364,10 @@ impl Pager { db_size, self.flush_info.borrow().in_flight_writes.clone(), )?; + // This page is no longer valid. + // For example: + // We took page with key (page_num, max_frame) -- this page is no longer valid for that max_frame so it must be invalidated. + cache.delete(page_key); } self.dirty_pages.borrow_mut().clear(); self.flush_info.borrow_mut().state = FlushState::WaitAppendFrames; diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index a55d180f4..7ec197517 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -454,25 +454,25 @@ impl PageContent { } pub fn write_u8(&self, pos: usize, value: u8) { - tracing::debug!("write_u8(pos={}, value={})", pos, value); + tracing::trace!("write_u8(pos={}, value={})", pos, value); let buf = self.as_ptr(); buf[self.offset + pos] = value; } pub fn write_u16(&self, pos: usize, value: u16) { - tracing::debug!("write_u16(pos={}, value={})", pos, value); + tracing::trace!("write_u16(pos={}, value={})", pos, value); let buf = self.as_ptr(); buf[self.offset + pos..self.offset + pos + 2].copy_from_slice(&value.to_be_bytes()); } pub fn write_u16_no_offset(&self, pos: usize, value: u16) { - tracing::debug!("write_u16(pos={}, value={})", pos, value); + tracing::trace!("write_u16(pos={}, value={})", pos, value); let buf = self.as_ptr(); buf[pos..pos + 2].copy_from_slice(&value.to_be_bytes()); } pub fn write_u32(&self, pos: usize, value: u32) { - tracing::debug!("write_u32(pos={}, value={})", pos, value); + tracing::trace!("write_u32(pos={}, value={})", pos, value); let buf = self.as_ptr(); buf[self.offset + pos..self.offset + pos + 4].copy_from_slice(&value.to_be_bytes()); } @@ -562,7 +562,7 @@ impl PageContent { payload_overflow_threshold_min: usize, usable_size: usize, ) -> Result { - tracing::debug!("cell_get(idx={})", idx); + tracing::trace!("cell_get(idx={})", idx); let buf = self.as_ptr(); let ncells = self.cell_count(); @@ -674,6 +674,28 @@ impl PageContent { let buf = self.as_ptr(); write_header_to_buf(buf, header); } + + pub fn debug_print_freelist(&self, usable_space: u16) { + let mut pc = self.first_freeblock() as usize; + let mut block_num = 0; + println!("---- Free List Blocks ----"); + println!("first freeblock pointer: {}", pc); + println!("cell content area: {}", self.cell_content_area()); + println!("fragmented bytes: {}", self.num_frag_free_bytes()); + + while pc != 0 && pc <= usable_space as usize { + let next = self.read_u16_no_offset(pc); + let size = self.read_u16_no_offset(pc + 2); + + println!( + "block {}: position={}, size={}, next={}", + block_num, pc, size, next + ); + pc = next as usize; + block_num += 1; + } + println!("--------------"); + } } pub fn begin_read_page( @@ -1308,7 +1330,7 @@ pub fn begin_read_wal_frame( let frame = page.clone(); let complete = Box::new(move |buf: Arc>| { let frame = frame.clone(); - finish_read_page(2, buf, frame).unwrap(); + finish_read_page(page.get().id, buf, frame).unwrap(); }); let c = Completion::Read(ReadCompletion::new(buf, complete)); io.pread(offset, c)?; diff --git a/core/storage/wal.rs b/core/storage/wal.rs index e9e610253..d6442b2bf 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -176,6 +176,7 @@ pub trait Wal { mode: CheckpointMode, ) -> Result; fn sync(&mut self) -> Result; + fn get_max_frame_in_wal(&self) -> u64; fn get_max_frame(&self) -> u64; fn get_min_frame(&self) -> u64; } @@ -333,8 +334,8 @@ impl Wal for WalFile { } } - // If we didn't find any mark, then let's add a new one - if max_read_mark_index == -1 { + // If we didn't find any mark or we can update, let's update them + if (max_read_mark as u64) < max_frame_in_wal || max_read_mark_index == -1 { for (index, lock) in shared.read_locks.iter_mut().enumerate() { let busy = !lock.write(); if !busy { @@ -361,10 +362,11 @@ impl Wal for WalFile { self.max_frame = max_read_mark as u64; self.min_frame = shared.nbackfills + 1; tracing::debug!( - "begin_read_tx(min_frame={}, max_frame={}, lock={})", + "begin_read_tx(min_frame={}, max_frame={}, lock={}, max_frame_in_wal={})", self.min_frame, self.max_frame, - self.max_frame_read_lock_index + self.max_frame_read_lock_index, + max_frame_in_wal ); Ok(LimboResult::Ok) } @@ -500,14 +502,18 @@ impl Wal for WalFile { // TODO(pere): check what frames are safe to checkpoint between many readers! self.ongoing_checkpoint.min_frame = self.min_frame; let mut shared = self.shared.write(); - let max_frame_in_wal = shared.max_frame as u32; let mut max_safe_frame = shared.max_frame; - for read_lock in shared.read_locks.iter_mut() { + for (read_lock_idx, read_lock) in shared.read_locks.iter_mut().enumerate() { let this_mark = read_lock.value.load(Ordering::SeqCst); if this_mark < max_safe_frame as u32 { let busy = !read_lock.write(); if !busy { - read_lock.value.store(max_frame_in_wal, Ordering::SeqCst); + let new_mark = if read_lock_idx == 0 { + max_safe_frame as u32 + } else { + READMARK_NOT_USED + }; + read_lock.value.store(new_mark, Ordering::SeqCst); read_lock.unlock(); } else { max_safe_frame = this_mark as u64; @@ -613,6 +619,7 @@ impl Wal for WalFile { shared.pages_in_frames.clear(); shared.max_frame = 0; shared.nbackfills = 0; + // TODO(pere): truncate wal file here. } else { shared.nbackfills = self.ongoing_checkpoint.max_frame; } @@ -658,6 +665,10 @@ impl Wal for WalFile { } } + fn get_max_frame_in_wal(&self) -> u64 { + self.shared.read().max_frame + } + fn get_max_frame(&self) -> u64 { self.max_frame } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 8a075bb54..24e7418b3 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -882,7 +882,7 @@ pub fn translate_expr( } #[cfg(feature = "json")] Func::Json(j) => match j { - JsonFunc::Json => { + JsonFunc::Json | JsonFunc::Jsonb => { let args = expect_arguments_exact!(args, 1, j); translate_function( diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index b65bd80d5..7621146d6 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -52,7 +52,7 @@ use crate::{ function::JsonFunc, json::get_json, json::is_json_valid, json::json_array, json::json_array_length, json::json_arrow_extract, json::json_arrow_shift_extract, json::json_error_position, json::json_extract, json::json_object, json::json_patch, - json::json_quote, json::json_remove, json::json_set, json::json_type, + json::json_quote, json::json_remove, json::json_set, json::json_type, json::jsonb, }; use crate::{info, CheckpointStatus}; use crate::{ @@ -2131,6 +2131,14 @@ impl Program { Err(e) => return Err(e), } } + JsonFunc::Jsonb => { + let json_value = &state.registers[*start_reg]; + let json_blob = jsonb(json_value); + match json_blob { + Ok(json) => state.registers[*dest] = json, + Err(e) => return Err(e), + } + } JsonFunc::JsonArray | JsonFunc::JsonObject => { let reg_values = &state.registers[*start_reg..*start_reg + arg_count]; @@ -3200,6 +3208,7 @@ impl Program { connection.deref(), ), TransactionState::Read => { + connection.transaction_state.replace(TransactionState::None); pager.end_read_tx()?; Ok(StepResult::Done) } @@ -3226,19 +3235,20 @@ impl Program { let checkpoint_status = pager.end_tx()?; match checkpoint_status { CheckpointStatus::Done(_) => { + if self.change_cnt_on { + if let Some(conn) = self.connection.upgrade() { + conn.set_changes(self.n_change.get()); + } + } connection.transaction_state.replace(TransactionState::None); let _ = halt_state.take(); } CheckpointStatus::IO => { + tracing::trace!("Checkpointing IO"); *halt_state = Some(HaltState::Checkpointing); return Ok(StepResult::IO); } } - if self.change_cnt_on { - if let Some(conn) = self.connection.upgrade() { - conn.set_changes(self.n_change.get()); - } - } Ok(StepResult::Done) } } diff --git a/extensions/core/Cargo.toml b/extensions/core/Cargo.toml index 1389e39c1..c6450a33d 100644 --- a/extensions/core/Cargo.toml +++ b/extensions/core/Cargo.toml @@ -13,3 +13,7 @@ static = [] [dependencies] limbo_macros = { workspace = true } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +getrandom = "0.3.1" +chrono = "0.4.40" diff --git a/extensions/core/README.md b/extensions/core/README.md index fd514165b..ae848b0d7 100644 --- a/extensions/core/README.md +++ b/extensions/core/README.md @@ -10,7 +10,7 @@ like traditional `sqlite3` extensions, but are able to be written in much more e - [ x ] **Scalar Functions**: Create scalar functions using the `scalar` macro. - [ x ] **Aggregate Functions**: Define aggregate functions with `AggregateDerive` macro and `AggFunc` trait. - [ x ] **Virtual tables**: Create a module for a virtual table with the `VTabModuleDerive` macro and `VTabCursor` trait. - - [] **VFS Modules** + - [ x ] **VFS Modules**: Extend Limbo's OS interface by implementing `VfsExtension` and `VfsFile` traits. --- ## Installation @@ -59,9 +59,14 @@ register_extension!{ scalars: { double }, // name of your function, if different from attribute name aggregates: { Percentile }, vtabs: { CsvVTable }, + vfs: { ExampleFS }, } ``` +**NOTE**: Currently, any Derive macro used from this crate is required to be in the same +file as the `register_extension` macro. + + ### Scalar Example: ```rust use limbo_ext::{register_extension, Value, scalar}; @@ -279,6 +284,106 @@ impl VTabCursor for CsvCursor { } ``` +### VFS Example + + +```rust +use limbo_ext::{ExtResult as Result, VfsDerive, VfsExtension, VfsFile}; + +/// Your struct must also impl Default +#[derive(VfsDerive, Default)] +struct ExampleFS; + + +struct ExampleFile { + file: std::fs::File, +} + +impl VfsExtension for ExampleFS { + /// The name of your vfs module + const NAME: &'static str = "example"; + + type File = ExampleFile; + + fn open(&self, path: &str, flags: i32, _direct: bool) -> Result { + let file = OpenOptions::new() + .read(true) + .write(true) + .create(flags & 1 != 0) + .open(path) + .map_err(|_| ResultCode::Error)?; + Ok(TestFile { file }) + } + + fn run_once(&self) -> Result<()> { + // (optional) method to cycle/advance IO, if your extension is asynchronous + Ok(()) + } + + fn close(&self, file: Self::File) -> Result<()> { + // (optional) method to close or drop the file + Ok(()) + } + + fn generate_random_number(&self) -> i64 { + // (optional) method to generate random number. Used for testing + let mut buf = [0u8; 8]; + getrandom::fill(&mut buf).unwrap(); + i64::from_ne_bytes(buf) + } + + fn get_current_time(&self) -> String { + // (optional) method to generate random number. Used for testing + chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string() + } +} + +impl VfsFile for ExampleFile { + fn read( + &mut self, + buf: &mut [u8], + count: usize, + offset: i64, + ) -> Result { + if file.file.seek(SeekFrom::Start(offset as u64)).is_err() { + return Err(ResultCode::Error); + } + file.file + .read(&mut buf[..count]) + .map_err(|_| ResultCode::Error) + .map(|n| n as i32) + } + + fn write(&mut self, buf: &[u8], count: usize, offset: i64) -> Result { + if self.file.seek(SeekFrom::Start(offset as u64)).is_err() { + return Err(ResultCode::Error); + } + self.file + .write(&buf[..count]) + .map_err(|_| ResultCode::Error) + .map(|n| n as i32) + } + + fn sync(&self) -> Result<()> { + self.file.sync_all().map_err(|_| ResultCode::Error) + } + + fn lock(&self, _exclusive: bool) -> Result<()> { + // (optional) method to lock the file + Ok(()) + } + + fn unlock(&self) -> Result<()> { + // (optional) method to lock the file + Ok(()) + } + + fn size(&self) -> i64 { + self.file.metadata().map(|m| m.len() as i64).unwrap_or(-1) + } +} +``` + ## Cargo.toml Config Edit the workspace `Cargo.toml` to include your extension as a workspace dependency, e.g: diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 03a4cac85..e81684e57 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -1,10 +1,16 @@ mod types; +mod vfs_modules; +#[cfg(not(target_family = "wasm"))] +pub use limbo_macros::VfsDerive; pub use limbo_macros::{register_extension, scalar, AggregateDerive, VTabModuleDerive}; use std::{ fmt::Display, os::raw::{c_char, c_void}, }; pub use types::{ResultCode, Value, ValueType}; +pub use vfs_modules::{RegisterVfsFn, VfsFileImpl, VfsImpl}; +#[cfg(not(target_family = "wasm"))] +pub use vfs_modules::{VfsExtension, VfsFile}; pub type ExtResult = std::result::Result; @@ -14,6 +20,36 @@ pub struct ExtensionApi { pub register_scalar_function: RegisterScalarFn, pub register_aggregate_function: RegisterAggFn, pub register_module: RegisterModuleFn, + pub register_vfs: RegisterVfsFn, + pub builtin_vfs: *mut *const VfsImpl, + pub builtin_vfs_count: i32, +} +unsafe impl Send for ExtensionApi {} +unsafe impl Send for ExtensionApiRef {} + +#[repr(C)] +pub struct ExtensionApiRef { + pub api: *const ExtensionApi, +} + +impl ExtensionApi { + /// Since we want the option to build in extensions at compile time as well, + /// we add a slice of VfsImpls to the extension API, and this is called with any + /// libraries that we load staticly that will add their VFS implementations to the list. + pub fn add_builtin_vfs(&mut self, vfs: *const VfsImpl) -> ResultCode { + if vfs.is_null() || self.builtin_vfs.is_null() { + return ResultCode::Error; + } + let mut new = unsafe { + let slice = + std::slice::from_raw_parts_mut(self.builtin_vfs, self.builtin_vfs_count as usize); + Vec::from(slice) + }; + new.push(vfs); + self.builtin_vfs = Box::into_raw(new.into_boxed_slice()) as *mut *const VfsImpl; + self.builtin_vfs_count += 1; + ResultCode::OK + } } pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode; diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 618f37e18..90adb3863 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -165,6 +165,7 @@ impl TextValue { }) } + #[cfg(feature = "core_only")] fn free(self) { if !self.text.is_null() { let _ = unsafe { Box::from_raw(self.text as *mut u8) }; @@ -231,7 +232,7 @@ impl Blob { } unsafe { std::slice::from_raw_parts(self.data, self.size as usize) } } - + #[cfg(feature = "core_only")] fn free(self) { if !self.data.is_null() { let _ = unsafe { Box::from_raw(self.data as *mut u8) }; diff --git a/extensions/core/src/vfs_modules.rs b/extensions/core/src/vfs_modules.rs new file mode 100644 index 000000000..67fd7c020 --- /dev/null +++ b/extensions/core/src/vfs_modules.rs @@ -0,0 +1,114 @@ +use crate::{ExtResult, ResultCode}; +use std::ffi::{c_char, c_void}; + +#[cfg(not(target_family = "wasm"))] +pub trait VfsExtension: Default + Send + Sync { + const NAME: &'static str; + type File: VfsFile; + fn open_file(&self, path: &str, flags: i32, direct: bool) -> ExtResult; + fn run_once(&self) -> ExtResult<()> { + Ok(()) + } + fn close(&self, _file: Self::File) -> ExtResult<()> { + Ok(()) + } + fn generate_random_number(&self) -> i64 { + let mut buf = [0u8; 8]; + getrandom::fill(&mut buf).unwrap(); + i64::from_ne_bytes(buf) + } + fn get_current_time(&self) -> String { + chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string() + } +} +#[cfg(not(target_family = "wasm"))] +pub trait VfsFile: Send + Sync { + fn lock(&mut self, _exclusive: bool) -> ExtResult<()> { + Ok(()) + } + fn unlock(&self) -> ExtResult<()> { + Ok(()) + } + fn read(&mut self, buf: &mut [u8], count: usize, offset: i64) -> ExtResult; + fn write(&mut self, buf: &[u8], count: usize, offset: i64) -> ExtResult; + fn sync(&self) -> ExtResult<()>; + fn size(&self) -> i64; +} + +#[repr(C)] +pub struct VfsImpl { + pub name: *const c_char, + pub vfs: *const c_void, + pub open: VfsOpen, + pub close: VfsClose, + pub read: VfsRead, + pub write: VfsWrite, + pub sync: VfsSync, + pub lock: VfsLock, + pub unlock: VfsUnlock, + pub size: VfsSize, + pub run_once: VfsRunOnce, + pub current_time: VfsGetCurrentTime, + pub gen_random_number: VfsGenerateRandomNumber, +} + +pub type RegisterVfsFn = + unsafe extern "C" fn(name: *const c_char, vfs: *const VfsImpl) -> ResultCode; + +pub type VfsOpen = unsafe extern "C" fn( + ctx: *const c_void, + path: *const c_char, + flags: i32, + direct: bool, +) -> *const c_void; + +pub type VfsClose = unsafe extern "C" fn(file: *const c_void) -> ResultCode; + +pub type VfsRead = + unsafe extern "C" fn(file: *const c_void, buf: *mut u8, count: usize, offset: i64) -> i32; + +pub type VfsWrite = + unsafe extern "C" fn(file: *const c_void, buf: *const u8, count: usize, offset: i64) -> i32; + +pub type VfsSync = unsafe extern "C" fn(file: *const c_void) -> i32; + +pub type VfsLock = unsafe extern "C" fn(file: *const c_void, exclusive: bool) -> ResultCode; + +pub type VfsUnlock = unsafe extern "C" fn(file: *const c_void) -> ResultCode; + +pub type VfsSize = unsafe extern "C" fn(file: *const c_void) -> i64; + +pub type VfsRunOnce = unsafe extern "C" fn(file: *const c_void) -> ResultCode; + +pub type VfsGetCurrentTime = unsafe extern "C" fn() -> *const c_char; + +pub type VfsGenerateRandomNumber = unsafe extern "C" fn() -> i64; + +#[repr(C)] +pub struct VfsFileImpl { + pub file: *const c_void, + pub vfs: *const VfsImpl, +} +unsafe impl Send for VfsFileImpl {} +unsafe impl Sync for VfsFileImpl {} + +impl VfsFileImpl { + pub fn new(file: *const c_void, vfs: *const VfsImpl) -> ExtResult { + if file.is_null() || vfs.is_null() { + return Err(ResultCode::Error); + } + Ok(Self { file, vfs }) + } +} + +impl Drop for VfsFileImpl { + fn drop(&mut self) { + if self.vfs.is_null() || self.file.is_null() { + return; + } + let vfs = unsafe { &*self.vfs }; + unsafe { + (vfs.close)(self.file); + } + } +} diff --git a/extensions/kvstore/Cargo.toml b/extensions/tests/Cargo.toml similarity index 75% rename from extensions/kvstore/Cargo.toml rename to extensions/tests/Cargo.toml index cac010bb6..aa3ba8fdb 100644 --- a/extensions/kvstore/Cargo.toml +++ b/extensions/tests/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "limbo_kv" +name = "limbo_ext_tests" version.workspace = true authors.workspace = true edition.workspace = true @@ -13,8 +13,10 @@ crate-type = ["cdylib", "lib"] static= [ "limbo_ext/static" ] [dependencies] +env_logger = "0.11.6" lazy_static = "1.5.0" limbo_ext = { workspace = true, features = ["static"] } +log = "0.4.26" [target.'cfg(not(target_family = "wasm"))'.dependencies] -mimalloc = { version = "*", default-features = false } +mimalloc = { version = "0.1", default-features = false } diff --git a/extensions/kvstore/src/lib.rs b/extensions/tests/src/lib.rs similarity index 61% rename from extensions/kvstore/src/lib.rs rename to extensions/tests/src/lib.rs index a9de7c71d..92e4f874f 100644 --- a/extensions/kvstore/src/lib.rs +++ b/extensions/tests/src/lib.rs @@ -1,16 +1,23 @@ use lazy_static::lazy_static; use limbo_ext::{ - register_extension, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, Value, + register_extension, scalar, ExtResult, ResultCode, VTabCursor, VTabKind, VTabModule, + VTabModuleDerive, Value, }; +#[cfg(not(target_family = "wasm"))] +use limbo_ext::{VfsDerive, VfsExtension, VfsFile}; use std::collections::BTreeMap; +use std::fs::{File, OpenOptions}; +use std::io::{Read, Seek, SeekFrom, Write}; use std::sync::Mutex; -lazy_static! { - static ref GLOBAL_STORE: Mutex> = Mutex::new(BTreeMap::new()); -} - register_extension! { vtabs: { KVStoreVTab }, + scalars: { test_scalar }, + vfs: { TestFS }, +} + +lazy_static! { + static ref GLOBAL_STORE: Mutex> = Mutex::new(BTreeMap::new()); } #[derive(VTabModuleDerive, Default)] @@ -128,7 +135,7 @@ impl VTabCursor for KVStoreCursor { if self.index.is_some_and(|c| c < self.rows.len()) { self.rows[self.index.unwrap_or(0)].0 } else { - println!("rowid: -1"); + log::error!("rowid: -1"); -1 } } @@ -145,3 +152,72 @@ impl VTabCursor for KVStoreCursor { ::next(self) } } + +pub struct TestFile { + file: File, +} + +#[cfg(target_family = "wasm")] +pub struct TestFS; + +#[cfg(not(target_family = "wasm"))] +#[derive(VfsDerive, Default)] +pub struct TestFS; + +// Test that we can have additional extension types in the same file +// and still register the vfs at comptime if linking staticly +#[scalar(name = "test_scalar")] +fn test_scalar(_args: limbo_ext::Value) -> limbo_ext::Value { + limbo_ext::Value::from_integer(42) +} + +#[cfg(not(target_family = "wasm"))] +impl VfsExtension for TestFS { + const NAME: &'static str = "testvfs"; + type File = TestFile; + fn open_file(&self, path: &str, flags: i32, _direct: bool) -> ExtResult { + let _ = env_logger::try_init(); + log::debug!("opening file with testing VFS: {} flags: {}", path, flags); + let file = OpenOptions::new() + .read(true) + .write(true) + .create(flags & 1 != 0) + .open(path) + .map_err(|_| ResultCode::Error)?; + Ok(TestFile { file }) + } +} + +#[cfg(not(target_family = "wasm"))] +impl VfsFile for TestFile { + fn read(&mut self, buf: &mut [u8], count: usize, offset: i64) -> ExtResult { + log::debug!("reading file with testing VFS: bytes: {count} offset: {offset}"); + if self.file.seek(SeekFrom::Start(offset as u64)).is_err() { + return Err(ResultCode::Error); + } + self.file + .read(&mut buf[..count]) + .map_err(|_| ResultCode::Error) + .map(|n| n as i32) + } + + fn write(&mut self, buf: &[u8], count: usize, offset: i64) -> ExtResult { + log::debug!("writing to file with testing VFS: bytes: {count} offset: {offset}"); + if self.file.seek(SeekFrom::Start(offset as u64)).is_err() { + return Err(ResultCode::Error); + } + self.file + .write(&buf[..count]) + .map_err(|_| ResultCode::Error) + .map(|n| n as i32) + } + + fn sync(&self) -> ExtResult<()> { + log::debug!("syncing file with testing VFS"); + self.file.sync_all().map_err(|_| ResultCode::Error) + } + + fn size(&self) -> i64 { + self.file.metadata().map(|m| m.len() as i64).unwrap_or(-1) + } +} diff --git a/macros/src/args.rs b/macros/src/args.rs index 12446b660..b0d45d20a 100644 --- a/macros/src/args.rs +++ b/macros/src/args.rs @@ -7,6 +7,7 @@ pub(crate) struct RegisterExtensionInput { pub aggregates: Vec, pub scalars: Vec, pub vtabs: Vec, + pub vfs: Vec, } impl syn::parse::Parse for RegisterExtensionInput { @@ -14,11 +15,12 @@ impl syn::parse::Parse for RegisterExtensionInput { let mut aggregates = Vec::new(); let mut scalars = Vec::new(); let mut vtabs = Vec::new(); + let mut vfs = Vec::new(); while !input.is_empty() { if input.peek(syn::Ident) && input.peek2(Token![:]) { let section_name: Ident = input.parse()?; input.parse::()?; - let names = ["aggregates", "scalars", "vtabs"]; + let names = ["aggregates", "scalars", "vtabs", "vfs"]; if names.contains(§ion_name.to_string().as_str()) { let content; syn::braced!(content in input); @@ -30,6 +32,7 @@ impl syn::parse::Parse for RegisterExtensionInput { "aggregates" => aggregates = parsed_items, "scalars" => scalars = parsed_items, "vtabs" => vtabs = parsed_items, + "vfs" => vfs = parsed_items, _ => unreachable!(), }; @@ -48,6 +51,7 @@ impl syn::parse::Parse for RegisterExtensionInput { aggregates, scalars, vtabs, + vfs, }) } } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index df8f8bd85..0fd69a4db 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -623,6 +623,222 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +#[proc_macro_derive(VfsDerive)] +pub fn derive_vfs_module(input: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(input as DeriveInput); + let struct_name = &derive_input.ident; + let register_fn_name = format_ident!("register_{}", struct_name); + let register_static = format_ident!("register_static_{}", struct_name); + let open_fn_name = format_ident!("{}_open", struct_name); + let close_fn_name = format_ident!("{}_close", struct_name); + let read_fn_name = format_ident!("{}_read", struct_name); + let write_fn_name = format_ident!("{}_write", struct_name); + let lock_fn_name = format_ident!("{}_lock", struct_name); + let unlock_fn_name = format_ident!("{}_unlock", struct_name); + let sync_fn_name = format_ident!("{}_sync", struct_name); + let size_fn_name = format_ident!("{}_size", struct_name); + let run_once_fn_name = format_ident!("{}_run_once", struct_name); + let generate_random_number_fn_name = format_ident!("{}_generate_random_number", struct_name); + let get_current_time_fn_name = format_ident!("{}_get_current_time", struct_name); + + let expanded = quote! { + #[allow(non_snake_case)] + pub unsafe extern "C" fn #register_static() -> *const ::limbo_ext::VfsImpl { + let ctx = #struct_name::default(); + let ctx = ::std::boxed::Box::into_raw(::std::boxed::Box::new(ctx)) as *const ::std::ffi::c_void; + let name = ::std::ffi::CString::new(<#struct_name as ::limbo_ext::VfsExtension>::NAME).unwrap().into_raw(); + let vfs_mod = ::limbo_ext::VfsImpl { + vfs: ctx, + name, + open: #open_fn_name, + close: #close_fn_name, + read: #read_fn_name, + write: #write_fn_name, + lock: #lock_fn_name, + unlock: #unlock_fn_name, + sync: #sync_fn_name, + size: #size_fn_name, + run_once: #run_once_fn_name, + gen_random_number: #generate_random_number_fn_name, + current_time: #get_current_time_fn_name, + }; + ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_mod)) as *const ::limbo_ext::VfsImpl + } + + #[no_mangle] + pub unsafe extern "C" fn #register_fn_name(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { + let ctx = #struct_name::default(); + let ctx = ::std::boxed::Box::into_raw(::std::boxed::Box::new(ctx)) as *const ::std::ffi::c_void; + let name = ::std::ffi::CString::new(<#struct_name as ::limbo_ext::VfsExtension>::NAME).unwrap().into_raw(); + let vfs_mod = ::limbo_ext::VfsImpl { + vfs: ctx, + name, + open: #open_fn_name, + close: #close_fn_name, + read: #read_fn_name, + write: #write_fn_name, + lock: #lock_fn_name, + unlock: #unlock_fn_name, + sync: #sync_fn_name, + size: #size_fn_name, + run_once: #run_once_fn_name, + gen_random_number: #generate_random_number_fn_name, + current_time: #get_current_time_fn_name, + }; + let vfsimpl = ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_mod)) as *const ::limbo_ext::VfsImpl; + (api.register_vfs)(name, vfsimpl) + } + + #[no_mangle] + pub unsafe extern "C" fn #open_fn_name( + ctx: *const ::std::ffi::c_void, + path: *const ::std::ffi::c_char, + flags: i32, + direct: bool, + ) -> *const ::std::ffi::c_void { + let ctx = &*(ctx as *const ::limbo_ext::VfsImpl); + let Ok(path_str) = ::std::ffi::CStr::from_ptr(path).to_str() else { + return ::std::ptr::null_mut(); + }; + let vfs = &*(ctx.vfs as *const #struct_name); + let Ok(file_handle) = <#struct_name as ::limbo_ext::VfsExtension>::open_file(vfs, path_str, flags, direct) else { + return ::std::ptr::null(); + }; + let boxed = ::std::boxed::Box::into_raw(::std::boxed::Box::new(file_handle)) as *const ::std::ffi::c_void; + let Ok(vfs_file) = ::limbo_ext::VfsFileImpl::new(boxed, ctx) else { + return ::std::ptr::null(); + }; + ::std::boxed::Box::into_raw(::std::boxed::Box::new(vfs_file)) as *const ::std::ffi::c_void + } + + #[no_mangle] + pub unsafe extern "C" fn #close_fn_name(file_ptr: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { + if file_ptr.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let vfs_instance = &*(vfs_file.vfs as *const #struct_name); + + // this time we need to own it so we can drop it + let file: ::std::boxed::Box<<#struct_name as ::limbo_ext::VfsExtension>::File> = + ::std::boxed::Box::from_raw(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::close(vfs_instance, *file) { + return e; + } + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #read_fn_name(file_ptr: *const ::std::ffi::c_void, buf: *mut u8, count: usize, offset: i64) -> i32 { + if file_ptr.is_null() { + return -1; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + match <#struct_name as ::limbo_ext::VfsExtension>::File::read(file, ::std::slice::from_raw_parts_mut(buf, count), count, offset) { + Ok(n) => n, + Err(_) => -1, + } + } + + #[no_mangle] + pub unsafe extern "C" fn #run_once_fn_name(ctx: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { + if ctx.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let ctx = &mut *(ctx as *mut #struct_name); + if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::run_once(ctx) { + return e; + } + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #write_fn_name(file_ptr: *const ::std::ffi::c_void, buf: *const u8, count: usize, offset: i64) -> i32 { + if file_ptr.is_null() { + return -1; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + match <#struct_name as ::limbo_ext::VfsExtension>::File::write(file, ::std::slice::from_raw_parts(buf, count), count, offset) { + Ok(n) => n, + Err(_) => -1, + } + } + + #[no_mangle] + pub unsafe extern "C" fn #lock_fn_name(file_ptr: *const ::std::ffi::c_void, exclusive: bool) -> ::limbo_ext::ResultCode { + if file_ptr.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::File::lock(file, exclusive) { + return e; + } + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #unlock_fn_name(file_ptr: *const ::std::ffi::c_void) -> ::limbo_ext::ResultCode { + if file_ptr.is_null() { + return ::limbo_ext::ResultCode::Error; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + if let Err(e) = <#struct_name as ::limbo_ext::VfsExtension>::File::unlock(file) { + return e; + } + ::limbo_ext::ResultCode::OK + } + + #[no_mangle] + pub unsafe extern "C" fn #sync_fn_name(file_ptr: *const ::std::ffi::c_void) -> i32 { + if file_ptr.is_null() { + return -1; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + if <#struct_name as ::limbo_ext::VfsExtension>::File::sync(file).is_err() { + return -1; + } + 0 + } + + #[no_mangle] + pub unsafe extern "C" fn #size_fn_name(file_ptr: *const ::std::ffi::c_void) -> i64 { + if file_ptr.is_null() { + return -1; + } + let vfs_file: &mut ::limbo_ext::VfsFileImpl = &mut *(file_ptr as *mut ::limbo_ext::VfsFileImpl); + let file: &mut <#struct_name as ::limbo_ext::VfsExtension>::File = + &mut *(vfs_file.file as *mut <#struct_name as ::limbo_ext::VfsExtension>::File); + <#struct_name as ::limbo_ext::VfsExtension>::File::size(file) + } + + #[no_mangle] + pub unsafe extern "C" fn #generate_random_number_fn_name() -> i64 { + let obj = #struct_name::default(); + <#struct_name as ::limbo_ext::VfsExtension>::generate_random_number(&obj) + } + + #[no_mangle] + pub unsafe extern "C" fn #get_current_time_fn_name() -> *const ::std::ffi::c_char { + let obj = #struct_name::default(); + let time = <#struct_name as ::limbo_ext::VfsExtension>::get_current_time(&obj); + // release ownership of the string to core + ::std::ffi::CString::new(time).unwrap().into_raw() as *const ::std::ffi::c_char + } + }; + + TokenStream::from(expanded) +} + /// Register your extension with 'core' by providing the relevant functions ///```ignore ///use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc}; @@ -662,6 +878,7 @@ pub fn register_extension(input: TokenStream) -> TokenStream { aggregates, scalars, vtabs, + vfs, } = input_ast; let scalar_calls = scalars.iter().map(|scalar_ident| { @@ -699,6 +916,29 @@ pub fn register_extension(input: TokenStream) -> TokenStream { } } }); + let vfs_calls = vfs.iter().map(|vfs_ident| { + let register_fn = syn::Ident::new(&format!("register_{}", vfs_ident), vfs_ident.span()); + quote! { + { + let result = unsafe { #register_fn(api) }; + if !result.is_ok() { + return result; + } + } + } + }); + let static_vfs = vfs.iter().map(|vfs_ident| { + let static_register = + syn::Ident::new(&format!("register_static_{}", vfs_ident), vfs_ident.span()); + quote! { + { + let result = api.add_builtin_vfs(unsafe { #static_register()}); + if !result.is_ok() { + return result; + } + } + } + }); let static_aggregates = aggregate_calls.clone(); let static_scalars = scalar_calls.clone(); let static_vtabs = vtab_calls.clone(); @@ -710,27 +950,30 @@ pub fn register_extension(input: TokenStream) -> TokenStream { static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[cfg(feature = "static")] - pub unsafe extern "C" fn register_extension_static(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { - let api = unsafe { &*api }; + pub unsafe extern "C" fn register_extension_static(api: &mut ::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { #(#static_scalars)* #(#static_aggregates)* #(#static_vtabs)* + #[cfg(not(target_family = "wasm"))] + #(#static_vfs)* + ::limbo_ext::ResultCode::OK } #[cfg(not(feature = "static"))] #[no_mangle] pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { - let api = unsafe { &*api }; #(#scalar_calls)* #(#aggregate_calls)* #(#vtab_calls)* + #(#vfs_calls)* + ::limbo_ext::ResultCode::OK } }; diff --git a/sqlite3/README.md b/sqlite3/README.md new file mode 100644 index 000000000..e58ea8c4a --- /dev/null +++ b/sqlite3/README.md @@ -0,0 +1,105 @@ +# SQLite3 Implementation for Limbo + +This directory contains a Rust implementation of the SQLite3 C API. The implementation serves as a compatibility layer between SQLite's C API and Limbo's native Rust database implementation. + +## Purpose + +This implementation provides SQLite3 API compatibility for Limbo, allowing existing applications that use SQLite to work with Limbo without modification. The code: + +1. Implements the SQLite3 C API functions in Rust +2. Translates between C and Rust data structures +3. Maps SQLite operations to equivalent Limbo operations +4. Maintains API compatibility with SQLite version 3.42.0 + +## Testing Strategy + +We employ a dual-testing approach to ensure complete compatibility with SQLite: + +### Test Database Setup + +Before running tests, you need to set up a test database: + +```bash +# Create testing directory +mkdir -p ../../testing + +# Create and initialize test database +sqlite3 ../../testing/testing.db ".databases" +``` + +This creates an empty SQLite database that both test suites will use. + +### 1. C Test Suite (`/tests`) +- Written in C to test the exact same API that real applications use +- Can be compiled and run against both: + - Official SQLite library (for verification) + - Our Rust implementation (for validation) +- Serves as the "source of truth" for correct behavior + +To run C tests against official SQLite: +```bash +cd tests +make clean +make LIBS="-lsqlite3" +./sqlite3-tests +``` + +To run C tests against our implementation: +```bash +cd tests +make clean +make LIBS="-L../target/debug -lsqlite3" +./sqlite3-tests +``` + +### 2. Rust Tests (`src/lib.rs`) +- Unit tests written in Rust +- Test the same functionality as C tests +- Provide better debugging capabilities +- Help with development and implementation + +To run Rust tests: +```bash +cargo test +``` + +### Why Two Test Suites? + +1. **Behavior Verification**: C tests ensure our implementation matches SQLite's behavior exactly by running the same tests against both +2. **Development Efficiency**: Rust tests provide better debugging and development experience +3. **Complete Coverage**: Both test suites together provide comprehensive testing from both C and Rust perspectives + +### Common Test Issues + +1. **Missing Test Database** + - Error: `SQLITE_CANTOPEN (14)` in tests + - Solution: Create test database as shown in "Test Database Setup" + +2. **Wrong Database Path** + - Tests expect database at `../../testing/testing.db` + - Verify path relative to where tests are run + +3. **Permission Issues** + - Ensure test database is readable/writable + - Check directory permissions + +## Implementation Notes + +- All public functions are marked with `#[no_mangle]` and follow SQLite's C API naming convention +- Uses `unsafe` blocks for C API compatibility +- Implements error handling similar to SQLite +- Maintains thread safety guarantees of SQLite + +## Contributing + +When adding new features or fixing bugs: + +1. Add C tests that can run against both implementations +2. Add corresponding Rust tests +3. Verify behavior matches SQLite by running C tests against both implementations +4. Ensure all existing tests pass in both suites +5. Make sure test database exists and is accessible + +## Status + +This is an ongoing implementation. Some functions are marked with `stub!()` macro, indicating they're not yet implemented. Check individual function documentation for implementation status. \ No newline at end of file diff --git a/sqlite3/include/sqlite3.h b/sqlite3/include/sqlite3.h index 6ddf3938b..530eef5aa 100644 --- a/sqlite3/include/sqlite3.h +++ b/sqlite3/include/sqlite3.h @@ -31,6 +31,12 @@ #define SQLITE_STATE_BUSY 109 +/* WAL Checkpoint modes */ +#define SQLITE_CHECKPOINT_PASSIVE 0 +#define SQLITE_CHECKPOINT_FULL 1 +#define SQLITE_CHECKPOINT_RESTART 2 +#define SQLITE_CHECKPOINT_TRUNCATE 3 + typedef struct sqlite3 sqlite3; typedef struct sqlite3_stmt sqlite3_stmt; @@ -244,6 +250,17 @@ const char *sqlite3_libversion(void); int sqlite3_libversion_number(void); +/* WAL Checkpoint functions */ +int sqlite3_wal_checkpoint(sqlite3 *db, const char *db_name); + +int sqlite3_wal_checkpoint_v2( + sqlite3 *db, + const char *db_name, + int mode, + int *log_size, + int *checkpoint_count +); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index c1cab5eb0..aeea6bb64 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -113,7 +113,7 @@ pub unsafe extern "C" fn sqlite3_open( ":memory:" => Arc::new(limbo_core::MemoryIO::new()), _ => match limbo_core::PlatformIO::new() { Ok(io) => Arc::new(io), - Err(_) => return SQLITE_MISUSE, + Err(_) => return SQLITE_CANTOPEN, }, }; match limbo_core::Database::open_file(io, filename, false) { @@ -122,7 +122,10 @@ pub unsafe extern "C" fn sqlite3_open( *db_out = Box::leak(Box::new(sqlite3::new(db, conn))); SQLITE_OK } - Err(_e) => SQLITE_CANTOPEN, + Err(e) => { + log::error!("error opening database {:?}", e); + SQLITE_CANTOPEN + } } } @@ -1079,3 +1082,187 @@ pub unsafe extern "C" fn sqlite3_wal_checkpoint_v2( } SQLITE_OK } + +#[cfg(test)] +mod tests { + use super::*; + use std::ptr; + + #[test] + fn test_libversion() { + unsafe { + let version = sqlite3_libversion(); + assert!(!version.is_null()); + } + } + + #[test] + fn test_libversion_number() { + unsafe { + let version_num = sqlite3_libversion_number(); + assert_eq!(version_num, 3042000); + } + } + + #[test] + fn test_open_misuse() { + unsafe { + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(ptr::null(), &mut db), SQLITE_MISUSE); + } + } + + #[test] + fn test_open_not_found() { + unsafe { + let mut db = ptr::null_mut(); + assert_eq!( + sqlite3_open(b"not-found/local.db\0".as_ptr() as *const i8, &mut db), + SQLITE_CANTOPEN + ); + } + } + + #[test] + #[ignore] + fn test_open_existing() { + unsafe { + let mut db = ptr::null_mut(); + assert_eq!( + sqlite3_open(b"../testing/testing.db\0".as_ptr() as *const i8, &mut db), + SQLITE_OK + ); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + fn test_close() { + unsafe { + assert_eq!(sqlite3_close(ptr::null_mut()), SQLITE_OK); + } + } + + #[test] + #[ignore] + fn test_prepare_misuse() { + unsafe { + let mut db = ptr::null_mut(); + assert_eq!( + sqlite3_open(b"../testing/testing.db\0".as_ptr() as *const i8, &mut db), + SQLITE_OK + ); + + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + b"SELECT 1\0".as_ptr() as *const i8, + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + #[ignore] + fn test_wal_checkpoint() { + unsafe { + // Test with NULL db handle + assert_eq!( + sqlite3_wal_checkpoint(ptr::null_mut(), ptr::null()), + SQLITE_MISUSE + ); + + // Test with valid db + let mut db = ptr::null_mut(); + assert_eq!( + sqlite3_open(b"../testing/testing.db\0".as_ptr() as *const i8, &mut db), + SQLITE_OK + ); + assert_eq!(sqlite3_wal_checkpoint(db, ptr::null()), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + + #[test] + #[ignore] + fn test_wal_checkpoint_v2() { + unsafe { + // Test with NULL db handle + assert_eq!( + sqlite3_wal_checkpoint_v2( + ptr::null_mut(), + ptr::null(), + SQLITE_CHECKPOINT_PASSIVE, + ptr::null_mut(), + ptr::null_mut() + ), + SQLITE_MISUSE + ); + + // Test with valid db + let mut db = ptr::null_mut(); + assert_eq!( + sqlite3_open(b"../testing/testing.db\0".as_ptr() as *const i8, &mut db), + SQLITE_OK + ); + + let mut log_size = 0; + let mut checkpoint_count = 0; + + // Test different checkpoint modes + assert_eq!( + sqlite3_wal_checkpoint_v2( + db, + ptr::null(), + SQLITE_CHECKPOINT_PASSIVE, + &mut log_size, + &mut checkpoint_count + ), + SQLITE_OK + ); + + assert_eq!( + sqlite3_wal_checkpoint_v2( + db, + ptr::null(), + SQLITE_CHECKPOINT_FULL, + &mut log_size, + &mut checkpoint_count + ), + SQLITE_OK + ); + + assert_eq!( + sqlite3_wal_checkpoint_v2( + db, + ptr::null(), + SQLITE_CHECKPOINT_RESTART, + &mut log_size, + &mut checkpoint_count + ), + SQLITE_OK + ); + + assert_eq!( + sqlite3_wal_checkpoint_v2( + db, + ptr::null(), + SQLITE_CHECKPOINT_TRUNCATE, + &mut log_size, + &mut checkpoint_count + ), + SQLITE_OK + ); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } +} diff --git a/sqlite3/tests/Makefile b/sqlite3/tests/Makefile index ae1b7bb4a..44786a3f2 100644 --- a/sqlite3/tests/Makefile +++ b/sqlite3/tests/Makefile @@ -1,44 +1,39 @@ -V = -ifeq ($(strip $(V)),) - E = @echo - Q = @ -else - E = @\# - Q = -endif -export E Q - -PROGRAM = sqlite3-tests - -CFLAGS = -g -Wall -std=c17 -MMD -MP +# Compiler settings +CC = gcc +CFLAGS = -g -Wall -std=c17 -I../include +# Libraries LIBS ?= -lsqlite3 LIBS += -lm -OBJS += main.o -OBJS += test-aux.o -OBJS += test-close.o -OBJS += test-open.o -OBJS += test-prepare.o +# Target program +PROGRAM = sqlite3-tests +# Object files +OBJS = main.o \ + test-aux.o \ + test-close.o \ + test-open.o \ + test-prepare.o \ + test-wal.o + +# Default target all: $(PROGRAM) +# Test target test: $(PROGRAM) - $(E) " TEST" - $(Q) $(CURDIR)/$(PROGRAM) + ./$(PROGRAM) +# Compile source files %.o: %.c - $(E) " CC " $@ - $(Q) $(CC) $(CFLAGS) -c $< -o $@ -I$(HEADERS) + $(CC) $(CFLAGS) -c $< -o $@ +# Link program $(PROGRAM): $(OBJS) - $(E) " LINK " $@ - $(Q) $(CC) -o $@ $^ $(LIBS) + $(CC) -o $@ $(OBJS) $(LIBS) +# Clean target clean: - $(E) " CLEAN" - $(Q) rm -f $(PROGRAM) - $(Q) rm -f $(OBJS) *.d -.PHONY: clean + rm -f $(PROGRAM) $(OBJS) --include $(OBJS:.o=.d) +.PHONY: all test clean diff --git a/sqlite3/tests/main.c b/sqlite3/tests/main.c index 0166aa860..4cbf19f5e 100644 --- a/sqlite3/tests/main.c +++ b/sqlite3/tests/main.c @@ -5,6 +5,8 @@ extern void test_open_not_found(); extern void test_open_existing(); extern void test_close(); extern void test_prepare_misuse(); +extern void test_wal_checkpoint(); +extern void test_wal_checkpoint_v2(); int main(int argc, char *argv[]) { @@ -15,6 +17,8 @@ int main(int argc, char *argv[]) test_open_existing(); test_close(); test_prepare_misuse(); + test_wal_checkpoint(); + test_wal_checkpoint_v2(); return 0; } diff --git a/sqlite3/tests/test-wal.c b/sqlite3/tests/test-wal.c new file mode 100644 index 000000000..490277e02 --- /dev/null +++ b/sqlite3/tests/test-wal.c @@ -0,0 +1,39 @@ +#include "check.h" + +#include +#include +#include +#include + +void test_wal_checkpoint(void) +{ + sqlite3 *db; + + // Test with NULL db handle + CHECK_EQUAL(SQLITE_MISUSE, sqlite3_wal_checkpoint(NULL, NULL)); + + // Test with valid db + CHECK_EQUAL(SQLITE_OK, sqlite3_open("../../testing/testing.db", &db)); + CHECK_EQUAL(SQLITE_OK, sqlite3_wal_checkpoint(db, NULL)); + CHECK_EQUAL(SQLITE_OK, sqlite3_close(db)); +} + +void test_wal_checkpoint_v2(void) +{ + sqlite3 *db; + int log_size, checkpoint_count; + + // Test with NULL db handle + CHECK_EQUAL(SQLITE_MISUSE, sqlite3_wal_checkpoint_v2(NULL, NULL, SQLITE_CHECKPOINT_PASSIVE, NULL, NULL)); + + // Test with valid db + CHECK_EQUAL(SQLITE_OK, sqlite3_open("../../testing/testing.db", &db)); + + // Test different checkpoint modes + CHECK_EQUAL(SQLITE_OK, sqlite3_wal_checkpoint_v2(db, NULL, SQLITE_CHECKPOINT_PASSIVE, &log_size, &checkpoint_count)); + CHECK_EQUAL(SQLITE_OK, sqlite3_wal_checkpoint_v2(db, NULL, SQLITE_CHECKPOINT_FULL, &log_size, &checkpoint_count)); + CHECK_EQUAL(SQLITE_OK, sqlite3_wal_checkpoint_v2(db, NULL, SQLITE_CHECKPOINT_RESTART, &log_size, &checkpoint_count)); + CHECK_EQUAL(SQLITE_OK, sqlite3_wal_checkpoint_v2(db, NULL, SQLITE_CHECKPOINT_TRUNCATE, &log_size, &checkpoint_count)); + + CHECK_EQUAL(SQLITE_OK, sqlite3_close(db)); +} \ No newline at end of file diff --git a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py index a1d278c90..d53e75b22 100755 --- a/testing/cli_tests/extensions.py +++ b/testing/cli_tests/extensions.py @@ -337,7 +337,7 @@ def test_series(): def test_kv(): - ext_path = "./target/debug/liblimbo_kv" + ext_path = "target/debug/liblimbo_ext_tests" limbo = TestLimboShell() limbo.run_test_fn( "create virtual table t using kv_store;", @@ -401,17 +401,18 @@ def test_kv(): ) limbo.quit() + def test_ipaddr(): limbo = TestLimboShell() ext_path = "./target/debug/liblimbo_ipaddr" - + limbo.run_test_fn( "SELECT ipfamily('192.168.1.1');", lambda res: "error: no such function: " in res, "ipfamily function returns null when ext not loaded", ) limbo.execute_dot(f".load {ext_path}") - + limbo.run_test_fn( "SELECT ipfamily('192.168.1.1');", lambda res: "4" == res, @@ -455,7 +456,7 @@ def test_ipaddr(): lambda res: "128" == res, "ipmasklen function returns the mask length for IPv6", ) - + limbo.run_test_fn( "SELECT ipnetwork('192.168.16.12/24');", lambda res: "192.168.16.0/24" == res, @@ -466,7 +467,76 @@ def test_ipaddr(): lambda res: "2001:db8::1/128" == res, "ipnetwork function returns the network for IPv6", ) - + limbo.quit() + + +def test_vfs(): + limbo = TestLimboShell() + ext_path = "target/debug/liblimbo_ext_tests" + limbo.run_test_fn(".vfslist", lambda x: "testvfs" not in x, "testvfs not loaded") + limbo.execute_dot(f".load {ext_path}") + limbo.run_test_fn( + ".vfslist", lambda res: "testvfs" in res, "testvfs extension loaded" + ) + limbo.execute_dot(".open testing/vfs.db testvfs") + limbo.execute_dot("create table test (id integer primary key, value float);") + limbo.execute_dot("create table vfs (id integer primary key, value blob);") + for i in range(50): + limbo.execute_dot("insert into test (value) values (randomblob(32*1024));") + limbo.execute_dot(f"insert into vfs (value) values ({i});") + limbo.run_test_fn( + "SELECT count(*) FROM test;", + lambda res: res == "50", + "Tested large write to testfs", + ) + limbo.run_test_fn( + "SELECT count(*) FROM vfs;", + lambda res: res == "50", + "Tested large write to testfs", + ) + print("Tested large write to testfs") + # open regular db file to ensure we don't segfault when vfs file is dropped + limbo.execute_dot(".open testing/vfs.db") + limbo.execute_dot("create table test (id integer primary key, value float);") + limbo.execute_dot("insert into test (value) values (1.0);") + limbo.quit() + + +def test_sqlite_vfs_compat(): + sqlite = TestLimboShell( + init_commands="", + exec_name="sqlite3", + flags="testing/vfs.db", + ) + sqlite.run_test_fn( + ".show", + lambda res: "filename: testing/vfs.db" in res, + "Opened db file created with vfs extension in sqlite3", + ) + sqlite.run_test_fn( + ".schema", + lambda res: "CREATE TABLE test (id integer PRIMARY KEY, value float);" in res, + "Tables created by vfs extension exist in db file", + ) + sqlite.run_test_fn( + "SELECT count(*) FROM test;", + lambda res: res == "50", + "Tested large write to testfs", + ) + sqlite.run_test_fn( + "SELECT count(*) FROM vfs;", + lambda res: res == "50", + "Tested large write to testfs", + ) + sqlite.quit() + + +def cleanup(): + if os.path.exists("testing/vfs.db"): + os.remove("testing/vfs.db") + if os.path.exists("testing/vfs.db-wal"): + os.remove("testing/vfs.db-wal") + if __name__ == "__main__": try: @@ -477,7 +547,11 @@ if __name__ == "__main__": test_series() test_kv() test_ipaddr() + test_vfs() + test_sqlite_vfs_compat() except Exception as e: print(f"Test FAILED: {e}") + cleanup() exit(1) + cleanup() print("All tests passed successfully.") diff --git a/testing/cli_tests/test_limbo_cli.py b/testing/cli_tests/test_limbo_cli.py index ad82952a6..38186bf48 100755 --- a/testing/cli_tests/test_limbo_cli.py +++ b/testing/cli_tests/test_limbo_cli.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 import os import select +from time import sleep import subprocess -from dataclasses import dataclass, field from pathlib import Path from typing import Callable, List, Optional @@ -10,16 +10,14 @@ from typing import Callable, List, Optional PIPE_BUF = 4096 -@dataclass class ShellConfig: - sqlite_exec: str = os.getenv("LIMBO_TARGET", "./target/debug/limbo") - sqlite_flags: List[str] = field( - default_factory=lambda: os.getenv("SQLITE_FLAGS", "-q").split() - ) - cwd = os.getcwd() - test_dir: Path = field(default_factory=lambda: Path("testing")) - py_folder: Path = field(default_factory=lambda: Path("cli_tests")) - test_files: Path = field(default_factory=lambda: Path("test_files")) + def __init__(self, exe_name, flags: str = "-q"): + self.sqlite_exec: str = exe_name + self.sqlite_flags: List[str] = flags.split() + self.cwd = os.getcwd() + self.test_dir: Path = Path("testing") + self.py_folder: Path = Path("cli_tests") + self.test_files: Path = Path("test_files") class LimboShell: @@ -92,14 +90,24 @@ class LimboShell: def quit(self) -> None: self._write_to_pipe(".quit") + sleep(0.3) self.pipe.terminate() + self.pipe.kill() class TestLimboShell: def __init__( - self, init_commands: Optional[str] = None, init_blobs_table: bool = False + self, + init_commands: Optional[str] = None, + init_blobs_table: bool = False, + exec_name: Optional[str] = None, + flags="", ): - self.config = ShellConfig() + if exec_name is None: + exec_name = "./target/debug/limbo" + if flags == "": + flags = "-q" + self.config = ShellConfig(exe_name=exec_name, flags=flags) if init_commands is None: # Default initialization init_commands = """ @@ -132,6 +140,11 @@ INSERT INTO t VALUES (zeroblob(1024 - 1), zeroblob(1024 - 2), zeroblob(1024 - 3) f"Actual:\n{repr(actual)}" ) + def debug_print(self, sql: str): + print(f"debugging: {sql}") + actual = self.shell.execute(sql) + print(f"OUTPUT:\n{repr(actual)}") + def run_test_fn( self, sql: str, validate: Callable[[str], bool], desc: str = "" ) -> None: diff --git a/testing/json.test b/testing/json.test index d5fa827d9..c6dc99553 100755 --- a/testing/json.test +++ b/testing/json.test @@ -682,9 +682,12 @@ do_execsql_test json_valid_1 { do_execsql_test json_valid_2 { SELECT json_valid('["a",55,"b",72]'); } {1} -do_execsql_test json_valid_3 { - SELECT json_valid( CAST('{"a":1}' AS BLOB) ); -} {1} +# +# Unimplemented +#do_execsql_test json_valid_3 { +# SELECT json_valid( CAST('{"a":"1}' AS BLOB) ); +#} {0} +# do_execsql_test json_valid_4 { SELECT json_valid(123); } {1} @@ -700,9 +703,7 @@ do_execsql_test json_valid_7 { do_execsql_test json_valid_8 { SELECT json_valid('{"a":55 "b":72}'); } {0} -do_execsql_test json_valid_3 { - SELECT json_valid( CAST('{"a":"1}' AS BLOB) ); -} {0} + do_execsql_test json_valid_9 { SELECT json_valid(NULL); } {} @@ -906,6 +907,80 @@ do_execsql_test json_quote_json_value { SELECT json_quote(json('{a:1, b: "test"}')); } {{{"a":1,"b":"test"}}} +do_execsql_test json_basics { + SELECT json(jsonb('{"name":"John", "age":30, "city":"New York"}')); +} {{{"name":"John","age":30,"city":"New York"}}} + +do_execsql_test json_complex_nested { + SELECT json(jsonb('{"complex": {"nested": ["array", "of", "values"], "numbers": [1, 2, 3]}}')); +} {{{"complex":{"nested":["array","of","values"],"numbers":[1,2,3]}}}} + +do_execsql_test json_array_of_objects { + SELECT json(jsonb('[{"id": 1, "data": "value1"}, {"id": 2, "data": "value2"}]')); +} {{[{"id":1,"data":"value1"},{"id":2,"data":"value2"}]}} + +do_execsql_test json_special_chars { + SELECT json(jsonb('{"special_chars": "!@#$%^&*()_+", "quotes": "\"quoted text\""}')); +} {{{"special_chars":"!@#$%^&*()_+","quotes":"\"quoted text\""}}} + +do_execsql_test json_unicode_emoji { + SELECT json(jsonb('{"unicode": "こんにちは世界", "emoji": "🚀🔥💯"}')); +} {{{"unicode":"こんにちは世界","emoji":"🚀🔥💯"}}} + +do_execsql_test json_value_types { + SELECT json(jsonb('{"boolean": true, "null_value": null, "number": 42.5}')); +} {{{"boolean":true,"null_value":null,"number":42.5}}} + +do_execsql_test json_deeply_nested { + SELECT json(jsonb('{"deeply": {"nested": {"structure": {"with": "values"}}}}')); +} {{{"deeply":{"nested":{"structure":{"with":"values"}}}}}} + +do_execsql_test json_mixed_array { + SELECT json(jsonb('{"array_mixed": [1, "text", true, null, {"obj": "inside array"}]}')); +} {{{"array_mixed":[1,"text",true,null,{"obj":"inside array"}]}}} + +do_execsql_test json_single_line_comments { + SELECT json(jsonb('{"name": "John", // This is a comment + "age": 30}')); +} {{{"name":"John","age":30}}} + +do_execsql_test json_multi_line_comments { + SELECT json(jsonb('{"data": "value", /* This is a + multi-line comment that spans + several lines */ "more": "data"}')); +} {{{"data":"value","more":"data"}}} + +do_execsql_test json_trailing_commas { + SELECT json(jsonb('{"items": ["one", "two", "three",], "status": "complete",}')); +} {{{"items":["one","two","three"],"status":"complete"}}} + +do_execsql_test json_unquoted_keys { + SELECT json(jsonb('{name: "Alice", age: 25}')); +} {{{"name":"Alice","age":25}}} + +do_execsql_test json_newlines { + SELECT json(jsonb('{"description": "Text with \nnew lines\nand more\nformatting"}')); +} {{{"description":"Text with \nnew lines\nand more\nformatting"}}} + +do_execsql_test json_hex_values { + SELECT json(jsonb('{"hex_value": "\x68\x65\x6c\x6c\x6f"}')); +} {{{"hex_value":"\u0068\u0065\u006c\u006c\u006f"}}} + +do_execsql_test json_unicode_escape { + SELECT json(jsonb('{"unicode": "\u0068\u0065\u006c\u006c\u006f"}')); +} {{{"unicode":"\u0068\u0065\u006c\u006c\u006f"}}} + +do_execsql_test json_tabs_whitespace { + SELECT json(jsonb('{"formatted": "Text with \ttabs and \tspacing"}')); +} {{{"formatted":"Text with \ttabs and \tspacing"}}} + +do_execsql_test json_mixed_escaping { + SELECT json(jsonb('{"mixed": "Newlines: \n Tabs: \t Quotes: \" Backslash: \\ Hex: \x40"}')); +} {{{"mixed":"Newlines: \n Tabs: \t Quotes: \" Backslash: \\ Hex: \u0040"}}} + +do_execsql_test json_control_chars { + SELECT json(jsonb('{"control": "Bell: \u0007 Backspace: \u0008 Form feed: \u000C"}')); +} {{{"control":"Bell: \u0007 Backspace: \u0008 Form feed: \u000C"}}} # Escape character tests in sqlite source depend on json_valid and in some syntax that is not implemented # yet in limbo. @@ -916,4 +991,3 @@ do_execsql_test json_quote_json_value { # WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x+1 FROM c WHERE x<0x1f) # SELECT sum(json_valid(json_quote('a'||char(x)||'z'))) FROM c ORDER BY x; # } {31} - diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index 0d9d68b41..dd313cfea 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -5,6 +5,7 @@ use log::debug; use std::rc::Rc; #[test] +#[ignore] fn test_simple_overflow_page() -> anyhow::Result<()> { let _ = env_logger::try_init(); let tmp_db = @@ -75,6 +76,7 @@ fn test_simple_overflow_page() -> anyhow::Result<()> { } #[test] +#[ignore] fn test_sequential_overflow_page() -> anyhow::Result<()> { let _ = env_logger::try_init(); let tmp_db = @@ -152,7 +154,7 @@ fn test_sequential_overflow_page() -> anyhow::Result<()> { } #[ignore] -#[test] +#[test_log::test] fn test_sequential_write() -> anyhow::Result<()> { let _ = env_logger::try_init(); @@ -162,7 +164,7 @@ fn test_sequential_write() -> anyhow::Result<()> { let list_query = "SELECT * FROM test"; let max_iterations = 10000; for i in 0..max_iterations { - debug!("inserting {} ", i); + println!("inserting {} ", i); if (i % 100) == 0 { let progress = (i as f64 / max_iterations as f64) * 100.0; println!("progress {:.1}%", progress);