diff --git a/COMPAT.md b/COMPAT.md index e85a47725..3d07558c8 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -226,7 +226,7 @@ Feature support of [sqlite expr syntax](https://www.sqlite.org/lang_expr.html). | length(X) | Yes | | | like(X,Y) | Yes | | | like(X,Y,Z) | Yes | | -| likelihood(X,Y) | No | | +| likelihood(X,Y) | Yes | | | likely(X) | Yes | | | load_extension(X) | Yes | sqlite3 extensions not yet supported | | load_extension(X,Y) | No | | @@ -328,7 +328,7 @@ Feature support of [sqlite expr syntax](https://www.sqlite.org/lang_expr.html). | julianday() | Partial | does not support modifiers | | unixepoch() | Partial | does not support modifiers | | strftime() | Yes | partially supports modifiers | -| timediff() | No | | +| timediff() | Yes | partially supports modifiers | Modifiers: diff --git a/Cargo.lock b/Cargo.lock index eb7943c70..eb6d9b620 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1917,11 +1917,16 @@ dependencies = [ name = "limbo_stress" version = "0.0.19-pre.4" dependencies = [ + "anarchist-readable-name-generator-lib", "antithesis_sdk", "clap", + "hex", "limbo", "serde_json", "tokio", + "tracing", + "tracing-appender", + "tracing-subscriber", ] [[package]] @@ -2588,9 +2593,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f1c6c3591120564d64db2261bec5f910ae454f01def849b9c22835a84695e86" +checksum = "17da310086b068fbdcefbba30aeb3721d5bb9af8db4987d6735b2183ca567229" dependencies = [ "anyhow", "cfg-if", @@ -2607,9 +2612,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9b6c2b34cf71427ea37c7001aefbaeb85886a074795e35f161f5aecc7620a7a" +checksum = "e27165889bd793000a098bb966adc4300c312497ea25cf7a690a9f0ac5aa5fc1" dependencies = [ "once_cell", "target-lexicon", @@ -2617,9 +2622,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5507651906a46432cdda02cd02dd0319f6064f1374c9147c45b978621d2c3a9c" +checksum = "05280526e1dbf6b420062f3ef228b78c0c54ba94e157f5cb724a609d0f2faabc" dependencies = [ "libc", "pyo3-build-config", @@ -2627,9 +2632,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0d394b5b4fd8d97d48336bb0dd2aebabad39f1d294edd6bcd2cccf2eefe6f42" +checksum = "5c3ce5686aa4d3f63359a5100c62a127c9f15e8398e5fdeb5deef1fed5cd5f44" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2639,9 +2644,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd72da09cfa943b1080f621f024d2ef7e2773df7badd51aa30a2be1f8caa7c8e" +checksum = "f4cf6faa0cbfb0ed08e89beb8103ae9724eb4750e3a78084ba4017cbe94f3855" dependencies = [ "heck", "proc-macro2", diff --git a/Dockerfile.antithesis b/Dockerfile.antithesis index 056ad0947..1f4f3ba10 100644 --- a/Dockerfile.antithesis +++ b/Dockerfile.antithesis @@ -14,6 +14,7 @@ COPY ./Cargo.lock ./Cargo.lock COPY ./Cargo.toml ./Cargo.toml COPY ./bindings/go ./bindings/go/ COPY ./bindings/java ./bindings/java/ +COPY ./bindings/javascript ./bindings/javascript/ COPY ./bindings/python ./bindings/python/ COPY ./bindings/rust ./bindings/rust/ COPY ./bindings/wasm ./bindings/wasm/ diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 402a3a760..4a8eaef59 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -18,7 +18,7 @@ extension-module = ["pyo3/extension-module"] [dependencies] anyhow = "1.0" limbo_core = { path = "../../core", features = ["io_uring"] } -pyo3 = { version = "0.24.0", features = ["anyhow"] } +pyo3 = { version = "0.24.1", features = ["anyhow"] } [build-dependencies] version_check = "0.9.5" diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 60a7ffd77..61e6271c9 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -17,11 +17,13 @@ pub enum Error { ToSqlConversionFailure(BoxError), #[error("Mutex lock error: {0}")] MutexError(String), + #[error("SQL execution failure: `{0}`")] + SqlExecutionFailure(String), } impl From for Error { - fn from(_err: limbo_core::LimboError) -> Self { - todo!(); + fn from(err: limbo_core::LimboError) -> Self { + Error::SqlExecutionFailure(err.to_string()) } } diff --git a/cli/app.rs b/cli/app.rs index c5cb2ff4f..3f04ab9fe 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -818,22 +818,27 @@ impl<'a> Limbo<'a> { } pub fn init_tracing(&mut self) -> Result { - let (non_blocking, guard) = if let Some(file) = &self.opts.tracing_output { - tracing_appender::non_blocking( - std::fs::File::options() - .append(true) - .create(true) - .open(file)?, - ) - } else { - tracing_appender::non_blocking(std::io::stderr()) - }; + let ((non_blocking, guard), should_emit_ansi) = + if let Some(file) = &self.opts.tracing_output { + ( + tracing_appender::non_blocking( + std::fs::File::options() + .append(true) + .create(true) + .open(file)?, + ), + false, + ) + } else { + (tracing_appender::non_blocking(std::io::stderr()), true) + }; if let Err(e) = tracing_subscriber::registry() .with( tracing_subscriber::fmt::layer() .with_writer(non_blocking) .with_line_number(true) - .with_thread_ids(true), + .with_thread_ids(true) + .with_ansi(should_emit_ansi), ) .with(EnvFilter::from_default_env()) .try_init() diff --git a/cli/input.rs b/cli/input.rs index e352899c9..eac5312dc 100644 --- a/cli/input.rs +++ b/cli/input.rs @@ -43,7 +43,7 @@ impl Default for Io { true => { #[cfg(all(target_os = "linux", feature = "io_uring"))] { - Io::IoUring + Io::Syscall // FIXME: make io_uring faster so it can be the default } #[cfg(any( not(target_os = "linux"), diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 270bee682..939fe3e05 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -89,12 +89,12 @@ impl Database { path: &str, vfs: &str, ) -> crate::Result<(Arc, Arc)> { - use crate::{MemoryIO, PlatformIO}; + use crate::{MemoryIO, SyscallIO}; use dynamic::get_vfs_modules; let io: Arc = match vfs { "memory" => Arc::new(MemoryIO::new()), - "syscall" => Arc::new(PlatformIO::new()?), + "syscall" => Arc::new(SyscallIO::new()?), #[cfg(all(target_os = "linux", feature = "io_uring"))] "io_uring" => Arc::new(UringIO::new()?), other => match get_vfs_modules().iter().find(|v| v.0 == vfs) { diff --git a/core/function.rs b/core/function.rs index 4c235cca5..da246e719 100644 --- a/core/function.rs +++ b/core/function.rs @@ -293,6 +293,8 @@ pub enum ScalarFunc { StrfTime, Printf, Likely, + TimeDiff, + Likelihood, } impl Display for ScalarFunc { @@ -348,6 +350,8 @@ impl Display for ScalarFunc { Self::StrfTime => "strftime".to_string(), Self::Printf => "printf".to_string(), Self::Likely => "likely".to_string(), + Self::TimeDiff => "timediff".to_string(), + Self::Likelihood => "likelihood".to_string(), }; write!(f, "{}", str) } @@ -555,6 +559,12 @@ impl Func { } Ok(Self::Agg(AggFunc::Total)) } + "timediff" => { + if arg_count != 2 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Scalar(ScalarFunc::TimeDiff)) + } #[cfg(feature = "json")] "jsonb_group_array" => Ok(Self::Agg(AggFunc::JsonbGroupArray)), #[cfg(feature = "json")] @@ -599,6 +609,7 @@ impl Func { "sqlite_source_id" => Ok(Self::Scalar(ScalarFunc::SqliteSourceId)), "replace" => Ok(Self::Scalar(ScalarFunc::Replace)), "likely" => Ok(Self::Scalar(ScalarFunc::Likely)), + "likelihood" => Ok(Self::Scalar(ScalarFunc::Likelihood)), #[cfg(feature = "json")] "json" => Ok(Self::Json(JsonFunc::Json)), #[cfg(feature = "json")] diff --git a/core/functions/datetime.rs b/core/functions/datetime.rs index 294fbfb2d..864f61787 100644 --- a/core/functions/datetime.rs +++ b/core/functions/datetime.rs @@ -656,6 +656,61 @@ fn parse_modifier(modifier: &str) -> Result { } } +pub fn exec_timediff(values: &[Register]) -> OwnedValue { + if values.len() < 2 { + return OwnedValue::Null; + } + + let start = parse_naive_date_time(values[0].get_owned_value()); + let end = parse_naive_date_time(values[1].get_owned_value()); + + match (start, end) { + (Some(start), Some(end)) => { + let duration = start.signed_duration_since(end); + format_time_duration(&duration) + } + _ => OwnedValue::Null, + } +} + +/// Format the time duration as +/-YYYY-MM-DD HH:MM:SS.SSS as per SQLite's timediff() function +fn format_time_duration(duration: &chrono::Duration) -> OwnedValue { + let is_negative = duration.num_seconds() < 0; + + let abs_duration = if is_negative { + -duration.clone() + } else { + duration.clone() + }; + + let total_seconds = abs_duration.num_seconds(); + let hours = (total_seconds % 86400) / 3600; + let minutes = (total_seconds % 3600) / 60; + let seconds = total_seconds % 60; + + let days = total_seconds / 86400; + let years = days / 365; + let remaining_days = days % 365; + let months = 0; + + let total_millis = abs_duration.num_milliseconds(); + let millis = total_millis % 1000; + + let result = format!( + "{}{:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:03}", + if is_negative { "-" } else { "+" }, + years, + months, + remaining_days, + hours, + minutes, + seconds, + millis + ); + + OwnedValue::build_text(&result) +} + #[cfg(test)] mod tests { use super::*; @@ -1642,4 +1697,67 @@ mod tests { #[test] fn test_strftime() {} + + #[test] + fn test_exec_timediff() { + let start = OwnedValue::build_text("12:00:00"); + let end = OwnedValue::build_text("14:30:45"); + let expected = OwnedValue::build_text("-0000-00-00 02:30:45.000"); + assert_eq!( + exec_timediff(&[Register::OwnedValue(start), Register::OwnedValue(end)]), + expected + ); + + let start = OwnedValue::build_text("14:30:45"); + let end = OwnedValue::build_text("12:00:00"); + let expected = OwnedValue::build_text("+0000-00-00 02:30:45.000"); + assert_eq!( + exec_timediff(&[Register::OwnedValue(start), Register::OwnedValue(end)]), + expected + ); + + let start = OwnedValue::build_text("12:00:01.300"); + let end = OwnedValue::build_text("12:00:00.500"); + let expected = OwnedValue::build_text("+0000-00-00 00:00:00.800"); + assert_eq!( + exec_timediff(&[Register::OwnedValue(start), Register::OwnedValue(end)]), + expected + ); + + let start = OwnedValue::build_text("13:30:00"); + let end = OwnedValue::build_text("16:45:30"); + let expected = OwnedValue::build_text("-0000-00-00 03:15:30.000"); + assert_eq!( + exec_timediff(&[Register::OwnedValue(start), Register::OwnedValue(end)]), + expected + ); + + let start = OwnedValue::build_text("2023-05-10 23:30:00"); + let end = OwnedValue::build_text("2023-05-11 01:15:00"); + let expected = OwnedValue::build_text("-0000-00-00 01:45:00.000"); + assert_eq!( + exec_timediff(&[Register::OwnedValue(start), Register::OwnedValue(end)]), + expected + ); + + let start = OwnedValue::Null; + let end = OwnedValue::build_text("12:00:00"); + let expected = OwnedValue::Null; + assert_eq!( + exec_timediff(&[Register::OwnedValue(start), Register::OwnedValue(end)]), + expected + ); + + let start = OwnedValue::build_text("not a time"); + let end = OwnedValue::build_text("12:00:00"); + let expected = OwnedValue::Null; + assert_eq!( + exec_timediff(&[Register::OwnedValue(start), Register::OwnedValue(end)]), + expected + ); + + let start = OwnedValue::build_text("12:00:00"); + let expected = OwnedValue::Null; + assert_eq!(exec_timediff(&[Register::OwnedValue(start)]), expected); + } } diff --git a/core/io/mod.rs b/core/io/mod.rs index 1d3223128..1cda42380 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -190,7 +190,7 @@ cfg_block! { #[cfg(feature = "fs")] pub use unix::UnixIO; pub use unix::UnixIO as SyscallIO; - pub use io_uring::UringIO as PlatformIO; + pub use unix::UnixIO as PlatformIO; } #[cfg(any(all(target_os = "linux",not(feature = "io_uring")), target_os = "macos"))] { diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 8f4afb090..5d07b6b82 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -1,56 +1,82 @@ -use tracing::debug; - -use crate::storage::pager::Pager; -use crate::storage::sqlite3_ondisk::{ - read_u32, read_varint, BTreeCell, PageContent, PageType, TableInteriorCell, TableLeafCell, +use crate::{ + storage::{ + pager::Pager, + sqlite3_ondisk::{ + read_u32, read_varint, BTreeCell, PageContent, PageType, TableInteriorCell, + TableLeafCell, + }, + }, + translate::plan::IterationDirection, + MvCursor, }; -use crate::translate::plan::IterationDirection; -use crate::MvCursor; -use crate::types::{ - compare_immutable, CursorResult, ImmutableRecord, OwnedValue, RefValue, SeekKey, SeekOp, +use crate::{ + return_corrupt, + types::{ + compare_immutable, CursorResult, ImmutableRecord, OwnedValue, RefValue, SeekKey, SeekOp, + }, + LimboError, Result, }; -use crate::{return_corrupt, LimboError, Result}; -use std::cell::{Cell, Ref, RefCell}; -use std::cmp::Ordering; #[cfg(debug_assertions)] use std::collections::HashSet; -use std::pin::Pin; -use std::rc::Rc; - -use super::pager::PageRef; -use super::sqlite3_ondisk::{ - read_record, write_varint_to_vec, IndexInteriorCell, IndexLeafCell, OverflowCell, - DATABASE_HEADER_SIZE, +use std::{ + cell::{Cell, Ref, RefCell}, + cmp::Ordering, + pin::Pin, + rc::Rc, }; -/* - These are offsets of fields in the header of a b-tree page. -*/ +use super::{ + pager::PageRef, + sqlite3_ondisk::{ + read_record, write_varint_to_vec, IndexInteriorCell, IndexLeafCell, OverflowCell, + DATABASE_HEADER_SIZE, + }, +}; -/// type of btree page -> u8 -const PAGE_HEADER_OFFSET_PAGE_TYPE: usize = 0; -/// pointer to first freeblock -> u16 -/// The second field of the b-tree page header is the offset of the first freeblock, or zero if there are no freeblocks on the page. -/// A freeblock is a structure used to identify unallocated space within a b-tree page. -/// Freeblocks are organized as a chain. +/// The B-Tree page header is 12 bytes for interior pages and 8 bytes for leaf pages. /// -/// To be clear, freeblocks do not mean the regular unallocated free space to the left of the cell content area pointer, but instead -/// blocks of at least 4 bytes WITHIN the cell content area that are not in use due to e.g. deletions. -const PAGE_HEADER_OFFSET_FIRST_FREEBLOCK: usize = 1; -/// number of cells in the page -> u16 -const PAGE_HEADER_OFFSET_CELL_COUNT: usize = 3; -/// pointer to first byte of cell allocated content from top -> u16 -/// SQLite strives to place cells as far toward the end of the b-tree page as it can, -/// in order to leave space for future growth of the cell pointer array. -/// = the cell content area pointer moves leftward as cells are added to the page -const PAGE_HEADER_OFFSET_CELL_CONTENT_AREA: usize = 5; -/// number of fragmented bytes -> u8 -/// Fragments are isolated groups of 1, 2, or 3 unused bytes within the cell content area. -const PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT: usize = 7; -/// if internalnode, pointer right most pointer (saved separately from cells) -> u32 -const PAGE_HEADER_OFFSET_RIGHTMOST_PTR: usize = 8; +/// +--------+-----------------+-----------------+-----------------+--------+----- ..... ----+ +/// | Page | First Freeblock | Cell Count | Cell Content | Frag. | Right-most | +/// | Type | Offset | | Area Start | Bytes | pointer | +/// +--------+-----------------+-----------------+-----------------+--------+----- ..... ----+ +/// 0 1 2 3 4 5 6 7 8 11 +/// +pub mod offset { + /// Type of the B-Tree page (u8). + pub const BTREE_PAGE_TYPE: usize = 0; + + /// A pointer to the first freeblock (u16). + /// + /// This field of the B-Tree page header is an offset to the first freeblock, or zero if + /// there are no freeblocks on the page. A freeblock is a structure used to identify + /// unallocated space within a B-Tree page, organized as a chain. + /// + /// Please note that freeblocks do not mean the regular unallocated free space to the left + /// of the cell content area pointer, but instead blocks of at least 4 + /// bytes WITHIN the cell content area that are not in use due to e.g. + /// deletions. + pub const BTREE_FIRST_FREEBLOCK: usize = 1; + + /// The number of cells in the page (u16). + pub const BTREE_CELL_COUNT: usize = 3; + + /// A pointer to first byte of cell allocated content from top (u16). + /// + /// SQLite strives to place cells as far toward the end of the b-tree page as it can, in + /// order to leave space for future growth of the cell pointer array. This means that the + /// cell content area pointer moves leftward as cells are added to the page. + pub const BTREE_CELL_CONTENT_AREA: usize = 5; + + /// The number of fragmented bytes (u8). + /// + /// Fragments are isolated groups of 1, 2, or 3 unused bytes within the cell content area. + pub const BTREE_FRAGMENTED_BYTES_COUNT: usize = 7; + + /// The right-most pointer (saved separately from cells) (u32) + pub const BTREE_RIGHTMOST_PTR: usize = 8; +} /// Maximum depth of an SQLite B-Tree structure. Any B-Tree deeper than /// this will be declared corrupt. This value is calculated based on a @@ -229,7 +255,7 @@ impl BTreeKey<'_> { struct BalanceInfo { /// Old pages being balanced. pages_to_balance: Vec, - /// Bookkeeping of the rightmost pointer so the PAGE_HEADER_OFFSET_RIGHTMOST_PTR can be updated. + /// Bookkeeping of the rightmost pointer so the offset::BTREE_RIGHTMOST_PTR can be updated. rightmost_pointer: *mut u8, /// Divider cells of old pages divider_cells: Vec>, @@ -313,17 +339,6 @@ enum OverflowState { Done, } -/// Iteration state of the cursor. Can only be set once. -/// Once a SeekGT or SeekGE is performed, the cursor must iterate forwards and calling prev() is an error. -/// Similarly, once a SeekLT or SeekLE is performed, the cursor must iterate backwards and calling next() is an error. -/// When a SeekEQ or SeekRowid is performed, the cursor is NOT allowed to iterate further. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum IterationState { - Unset, - Iterating(IterationDirection), - IterationNotAllowed, -} - pub struct BTreeCursor { /// The multi-version cursor that is used to read and write to the database file. mv_cursor: Option>>, @@ -349,8 +364,6 @@ pub struct BTreeCursor { /// Reusable immutable record, used to allow better allocation strategy. reusable_immutable_record: RefCell>, empty_record: Cell, - - pub iteration_state: IterationState, } /// Stack of pages representing the tree traversal order. @@ -399,7 +412,6 @@ impl BTreeCursor { }, reusable_immutable_record: RefCell::new(None), empty_record: Cell::new(true), - iteration_state: IterationState::Unset, } } @@ -773,7 +785,7 @@ impl BTreeCursor { // end let has_parent = self.stack.current() > 0; if has_parent { - debug!("moving upwards"); + tracing::debug!("moving upwards"); self.going_upwards = true; self.stack.pop(); continue; @@ -943,35 +955,7 @@ impl BTreeCursor { /// or e.g. find the first record greater than the seek key in a range query (e.g. SELECT * FROM table WHERE col > 10). /// We don't include the rowid in the comparison and that's why the last value from the record is not included. fn do_seek(&mut self, key: SeekKey<'_>, op: SeekOp) -> Result>> { - assert!( - self.iteration_state != IterationState::Unset, - "iteration state must have been set before do_seek() is called" - ); - let valid_op = match (self.iteration_state, op) { - (IterationState::Iterating(IterationDirection::Forwards), SeekOp::GE | SeekOp::GT) => { - true - } - (IterationState::Iterating(IterationDirection::Backwards), SeekOp::LE | SeekOp::LT) => { - true - } - (IterationState::IterationNotAllowed, SeekOp::EQ) => true, - _ => false, - }; - assert!( - valid_op, - "invalid seek op for iteration state: {:?} {:?}", - self.iteration_state, op - ); - let cell_iter_dir = match self.iteration_state { - IterationState::Iterating(IterationDirection::Forwards) - | IterationState::IterationNotAllowed => IterationDirection::Forwards, - IterationState::Iterating(IterationDirection::Backwards) => { - IterationDirection::Backwards - } - IterationState::Unset => { - unreachable!("iteration state must have been set before do_seek() is called"); - } - }; + let cell_iter_dir = op.iteration_direction(); return_if_io!(self.move_to(key.clone(), op.clone())); { @@ -1117,19 +1101,13 @@ impl BTreeCursor { // if we were to return Ok(CursorResult::Ok((None, None))), self.record would be None, which is incorrect, because we already know // that there is a record with a key greater than K (K' = K+2) in the parent interior cell. Hence, we need to move back up the tree // and get the next matching record from there. - match self.iteration_state { - IterationState::Iterating(IterationDirection::Forwards) => { + match op.iteration_direction() { + IterationDirection::Forwards => { return self.get_next_record(Some((key, op))); } - IterationState::Iterating(IterationDirection::Backwards) => { + IterationDirection::Backwards => { return self.get_prev_record(Some((key, op))); } - IterationState::Unset => { - unreachable!("iteration state must not be unset"); - } - IterationState::IterationNotAllowed => { - unreachable!("iteration state must not be IterationNotAllowed"); - } } } @@ -1179,6 +1157,7 @@ impl BTreeCursor { pub fn move_to(&mut self, key: SeekKey<'_>, cmp: SeekOp) -> Result> { assert!(self.mv_cursor.is_none()); tracing::trace!("move_to(key={:?} cmp={:?})", key, cmp); + tracing::trace!("backtrace: {}", std::backtrace::Backtrace::force_capture()); // For a table with N rows, we can find any row by row id in O(log(N)) time by starting at the root page and following the B-tree pointers. // B-trees consist of interior pages and leaf pages. Interior pages contain pointers to other pages, while leaf pages contain the actual row data. // @@ -1204,12 +1183,7 @@ impl BTreeCursor { // 6. If we find the cell, we return the record. Otherwise, we return an empty result. self.move_to_root(); - let iter_dir = match self.iteration_state { - IterationState::Iterating(IterationDirection::Backwards) => { - IterationDirection::Backwards - } - _ => IterationDirection::Forwards, - }; + let iter_dir = cmp.iteration_direction(); loop { let page = self.stack.top(); @@ -1265,29 +1239,12 @@ impl BTreeCursor { // No iteration (point query): // EQ | > or = | go left | Last = key is in left subtree // EQ | < | go right | Last = key is in right subtree - let target_leaf_page_is_in_left_subtree = match (self.iteration_state, cmp) - { - ( - IterationState::Iterating(IterationDirection::Forwards), - SeekOp::GT, - ) => *cell_rowid > rowid_key, - ( - IterationState::Iterating(IterationDirection::Forwards), - SeekOp::GE, - ) => *cell_rowid >= rowid_key, - ( - IterationState::Iterating(IterationDirection::Backwards), - SeekOp::LE, - ) => *cell_rowid >= rowid_key, - ( - IterationState::Iterating(IterationDirection::Backwards), - SeekOp::LT, - ) => *cell_rowid >= rowid_key || *cell_rowid == rowid_key - 1, - (_any, SeekOp::EQ) => *cell_rowid >= rowid_key, - _ => unreachable!( - "invalid combination of seek op and iteration state: {:?} {:?}", - cmp, self.iteration_state - ), + let target_leaf_page_is_in_left_subtree = match cmp { + SeekOp::GT => *cell_rowid > rowid_key, + SeekOp::GE => *cell_rowid >= rowid_key, + SeekOp::LE => *cell_rowid >= rowid_key, + SeekOp::LT => *cell_rowid + 1 >= rowid_key, + SeekOp::EQ => *cell_rowid >= rowid_key, }; if target_leaf_page_is_in_left_subtree { // If we found our target rowid in the left subtree, @@ -1375,36 +1332,13 @@ impl BTreeCursor { // EQ | > | go left | First = key must be in left subtree // EQ | = | go left | First = key could be exactly this one, or in left subtree // EQ | < | go right | First = key must be in right subtree - assert!( - self.iteration_state != IterationState::Unset, - "iteration state must have been set before move_to() is called" - ); - let target_leaf_page_is_in_left_subtree = match (cmp, self.iteration_state) - { - ( - SeekOp::GT, - IterationState::Iterating(IterationDirection::Forwards), - ) => interior_cell_vs_index_key.is_gt(), - ( - SeekOp::GE, - IterationState::Iterating(IterationDirection::Forwards), - ) => interior_cell_vs_index_key.is_ge(), - (SeekOp::EQ, IterationState::IterationNotAllowed) => { - interior_cell_vs_index_key.is_ge() - } - ( - SeekOp::LE, - IterationState::Iterating(IterationDirection::Backwards), - ) => interior_cell_vs_index_key.is_gt(), - ( - SeekOp::LT, - IterationState::Iterating(IterationDirection::Backwards), - ) => interior_cell_vs_index_key.is_ge(), - _ => unreachable!( - "invalid combination of seek op and iteration state: {:?} {:?}", - cmp, self.iteration_state - ), + let target_leaf_page_is_in_left_subtree = match cmp { + SeekOp::GT => interior_cell_vs_index_key.is_gt(), + SeekOp::GE => interior_cell_vs_index_key.is_ge(), + SeekOp::EQ => interior_cell_vs_index_key.is_ge(), + SeekOp::LE => interior_cell_vs_index_key.is_gt(), + SeekOp::LT => interior_cell_vs_index_key.is_ge(), }; if target_leaf_page_is_in_left_subtree { // we don't advance in case of forward iteration and index tree internal nodes because we will visit this node going up. @@ -1549,7 +1483,7 @@ impl BTreeCursor { // insert let overflow = { let contents = page.get().contents.as_mut().unwrap(); - debug!( + tracing::debug!( "insert_into_page(overflow, cell_count={})", contents.cell_count() ); @@ -1630,12 +1564,6 @@ impl BTreeCursor { let write_info = self.state.mut_write_info().unwrap(); write_info.state = WriteState::BalanceNonRoot; self.stack.pop(); - // with `move_to` we advance the current cell idx of TableInterior once we move to left subtree. - // On the other hand, with IndexInterior, we do not because we tranver in-order. In the latter case - // since we haven't consumed the cell we can avoid retreating the current cell index. - if matches!(current_page.get_contents().page_type(), PageType::TableLeaf) { - self.stack.retreat(); - } return_if_io!(self.balance_non_root()); } WriteState::BalanceNonRoot | WriteState::BalanceNonRootWaitLoadPages => { @@ -1660,16 +1588,20 @@ impl BTreeCursor { WriteState::BalanceStart => todo!(), WriteState::BalanceNonRoot => { let parent_page = self.stack.top(); - if parent_page.is_locked() { - return Ok(CursorResult::IO); - } return_if_locked_maybe_load!(self.pager, parent_page); + // If `move_to` moved to rightmost page, cell index will be out of bounds. Meaning cell_count+1. + // In any other case, `move_to` will stay in the correct index. + if self.stack.current_cell_index() as usize + == parent_page.get_contents().cell_count() + 1 + { + self.stack.retreat(); + } parent_page.set_dirty(); self.pager.add_dirty(parent_page.get().id); let parent_contents = parent_page.get().contents.as_ref().unwrap(); let page_to_balance_idx = self.stack.current_cell_index() as usize; - debug!( + tracing::debug!( "balance_non_root(parent_id={} page_to_balance_idx={})", parent_page.get().id, page_to_balance_idx @@ -1899,6 +1831,7 @@ impl BTreeCursor { let mut count_cells_in_old_pages = Vec::new(); let page_type = balance_info.pages_to_balance[0].get_contents().page_type(); + tracing::debug!("balance_non_root(page_type={:?})", page_type); let leaf_data = matches!(page_type, PageType::TableLeaf); let leaf = matches!(page_type, PageType::TableLeaf | PageType::IndexLeaf); for (i, old_page) in balance_info.pages_to_balance.iter().enumerate() { @@ -2228,7 +2161,7 @@ impl BTreeCursor { let new_last_page = pages_to_balance_new.last().unwrap(); new_last_page .get_contents() - .write_u32(PAGE_HEADER_OFFSET_RIGHTMOST_PTR, right_pointer); + .write_u32(offset::BTREE_RIGHTMOST_PTR, right_pointer); } // TODO: pointer map update (vacuum support) // Update divider cells in parent @@ -2247,7 +2180,7 @@ impl BTreeCursor { // Make this page's rightmost pointer point to pointer of divider cell before modification let previous_pointer_divider = read_u32(÷r_cell, 0); page.get_contents() - .write_u32(PAGE_HEADER_OFFSET_RIGHTMOST_PTR, previous_pointer_divider); + .write_u32(offset::BTREE_RIGHTMOST_PTR, previous_pointer_divider); // divider cell now points to this page new_divider_cell.extend_from_slice(&(page.get().id as u32).to_be_bytes()); // now copy the rest of the divider cell: @@ -2535,6 +2468,7 @@ impl BTreeCursor { // Let's now make a in depth check that we in fact added all possible cells somewhere and they are not lost for (page_idx, page) in pages_to_balance_new.iter().enumerate() { let contents = page.get_contents(); + debug_validate_cells!(contents, self.usable_space() as u16); // Cells are distributed in order for cell_idx in 0..contents.cell_count() { let (cell_start, cell_len) = contents.cell_get_raw_region( @@ -2871,6 +2805,7 @@ impl BTreeCursor { &mut child_contents.overflow_cells, &mut root_contents.overflow_cells, ); + root_contents.overflow_cells.clear(); // 2. Modify root let new_root_page_type = match root_contents.page_type() { @@ -2879,16 +2814,13 @@ impl BTreeCursor { other => other, } as u8; // set new page type - root_contents.write_u8(PAGE_HEADER_OFFSET_PAGE_TYPE, new_root_page_type); - root_contents.write_u32(PAGE_HEADER_OFFSET_RIGHTMOST_PTR, child.get().id as u32); - root_contents.write_u16( - PAGE_HEADER_OFFSET_CELL_CONTENT_AREA, - self.usable_space() as u16, - ); - root_contents.write_u16(PAGE_HEADER_OFFSET_CELL_COUNT, 0); - root_contents.write_u16(PAGE_HEADER_OFFSET_FIRST_FREEBLOCK, 0); + root_contents.write_u8(offset::BTREE_PAGE_TYPE, new_root_page_type); + root_contents.write_u32(offset::BTREE_RIGHTMOST_PTR, child.get().id as u32); + root_contents.write_u16(offset::BTREE_CELL_CONTENT_AREA, self.usable_space() as u16); + root_contents.write_u16(offset::BTREE_CELL_COUNT, 0); + root_contents.write_u16(offset::BTREE_FIRST_FREEBLOCK, 0); - root_contents.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, 0); + root_contents.write_u8(offset::BTREE_FRAGMENTED_BYTES_COUNT, 0); root_contents.overflow_cells.clear(); self.root_page = root.get().id; self.stack.clear(); @@ -2999,14 +2931,6 @@ impl BTreeCursor { } pub fn rewind(&mut self) -> Result> { - assert!( - matches!( - self.iteration_state, - IterationState::Unset | IterationState::Iterating(IterationDirection::Forwards) - ), - "iteration state must be unset or Iterating(Forwards) when rewind() is called" - ); - self.iteration_state = IterationState::Iterating(IterationDirection::Forwards); if self.mv_cursor.is_some() { let rowid = return_if_io!(self.get_next_record(None)); self.rowid.replace(rowid); @@ -3022,14 +2946,6 @@ impl BTreeCursor { } pub fn last(&mut self) -> Result> { - assert!( - matches!( - self.iteration_state, - IterationState::Unset | IterationState::Iterating(IterationDirection::Backwards) - ), - "iteration state must be unset or Iterating(Backwards) when last() is called" - ); - self.iteration_state = IterationState::Iterating(IterationDirection::Backwards); assert!(self.mv_cursor.is_none()); match self.move_to_rightmost()? { CursorResult::Ok(_) => self.prev(), @@ -3038,14 +2954,6 @@ impl BTreeCursor { } pub fn next(&mut self) -> Result> { - assert!( - matches!( - self.iteration_state, - IterationState::Iterating(IterationDirection::Forwards) - ), - "iteration state must be Iterating(Forwards) when next() is called, but it was {:?}", - self.iteration_state - ); let rowid = return_if_io!(self.get_next_record(None)); self.rowid.replace(rowid); self.empty_record.replace(rowid.is_none()); @@ -3053,13 +2961,6 @@ impl BTreeCursor { } pub fn prev(&mut self) -> Result> { - assert!( - matches!( - self.iteration_state, - IterationState::Iterating(IterationDirection::Backwards) - ), - "iteration state must be Iterating(Backwards) when prev() is called" - ); assert!(self.mv_cursor.is_none()); match self.get_prev_record(None)? { CursorResult::Ok(rowid) => { @@ -3086,38 +2987,6 @@ impl BTreeCursor { pub fn seek(&mut self, key: SeekKey<'_>, op: SeekOp) -> Result> { assert!(self.mv_cursor.is_none()); - match op { - SeekOp::GE | SeekOp::GT => { - if self.iteration_state == IterationState::Unset { - self.iteration_state = IterationState::Iterating(IterationDirection::Forwards); - } else { - assert!(matches!( - self.iteration_state, - IterationState::Iterating(IterationDirection::Forwards) - )); - } - } - SeekOp::LE | SeekOp::LT => { - if self.iteration_state == IterationState::Unset { - self.iteration_state = IterationState::Iterating(IterationDirection::Backwards); - } else { - assert!(matches!( - self.iteration_state, - IterationState::Iterating(IterationDirection::Backwards) - )); - } - } - SeekOp::EQ => { - if self.iteration_state == IterationState::Unset { - self.iteration_state = IterationState::IterationNotAllowed; - } else { - assert!(matches!( - self.iteration_state, - IterationState::IterationNotAllowed - )); - } - } - }; let rowid = return_if_io!(self.do_seek(key, op)); self.rowid.replace(rowid); self.empty_record.replace(rowid.is_none()); @@ -3133,6 +3002,7 @@ impl BTreeCursor { key: &BTreeKey, moved_before: bool, /* Indicate whether it's necessary to traverse to find the leaf page */ ) -> Result> { + tracing::trace!("insert"); match &self.mv_cursor { Some(mv_cursor) => match key.maybe_rowid() { Some(rowid) => { @@ -3144,8 +3014,8 @@ impl BTreeCursor { None => todo!("Support mvcc inserts with index btrees"), }, None => { + tracing::trace!("moved {}", moved_before); if !moved_before { - self.iteration_state = IterationState::Iterating(IterationDirection::Forwards); match key { BTreeKey::IndexKey(_) => { return_if_io!(self @@ -3833,25 +3703,8 @@ impl BTreeCursor { }; // if it all fits in local space and old_local_size is enough, do an in-place overwrite - if new_payload.len() <= old_local_size { - self.overwrite_content( - page_ref.clone(), - old_offset, - &new_payload, - 0, - new_payload.len(), - )?; - let remaining = old_local_size - new_payload.len(); - if remaining > 0 { - // fill the rest with zeros - self.overwrite_content( - page_ref.clone(), - old_offset + new_payload.len(), - &[0; 1], - 0, - remaining, - )?; - } + if new_payload.len() == old_local_size { + self.overwrite_content(page_ref.clone(), old_offset, &new_payload)?; Ok(CursorResult::Ok(())) } else { // doesn't fit, drop it and insert a new one @@ -3875,36 +3728,11 @@ impl BTreeCursor { page_ref: PageRef, dest_offset: usize, new_payload: &[u8], - src_offset: usize, - amount: usize, ) -> Result> { return_if_locked!(page_ref); - page_ref.set_dirty(); - self.pager.add_dirty(page_ref.get().id); let buf = page_ref.get().contents.as_mut().unwrap().as_ptr(); + buf[dest_offset..dest_offset + new_payload.len()].copy_from_slice(&new_payload); - // if new_payload doesn't have enough data, we fill with zeros - let n_data = new_payload.len().saturating_sub(src_offset); - if n_data == 0 { - // everything is zeros - for i in 0..amount { - if buf[dest_offset + i] != 0 { - buf[dest_offset + i] = 0; - } - } - } else { - let copy_len = n_data.min(amount); - // copy the overlapping portion - buf[dest_offset..dest_offset + copy_len] - .copy_from_slice(&new_payload[src_offset..src_offset + copy_len]); - - // if copy_len < amount => fill remainder with 0 - if copy_len < amount { - for i in copy_len..amount { - buf[dest_offset + i] = 0; - } - } - } Ok(CursorResult::Ok(())) } @@ -4069,7 +3897,7 @@ impl CellArray { fn find_free_cell(page_ref: &PageContent, usable_space: u16, amount: usize) -> Result { // NOTE: freelist is in ascending order of keys and pc // unuse_space is reserved bytes at the end of page, therefore we must substract from maxpc - let mut prev_pc = page_ref.offset + PAGE_HEADER_OFFSET_FIRST_FREEBLOCK; + let mut prev_pc = page_ref.offset + offset::BTREE_FIRST_FREEBLOCK; let mut pc = page_ref.first_freeblock() as usize; let maxpc = usable_space as usize - amount; @@ -4091,16 +3919,16 @@ fn find_free_cell(page_ref: &PageContent, usable_space: u16, amount: usize) -> R return Ok(0); } // Delete the slot from freelist and update the page's fragment count. - page_ref.write_u16(prev_pc, next); + page_ref.write_u16_no_offset(prev_pc, next); let frag = page_ref.num_frag_free_bytes() + new_size as u8; - page_ref.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, frag); + page_ref.write_u8(offset::BTREE_FRAGMENTED_BYTES_COUNT, frag); return Ok(pc); } else if new_size + pc > maxpc { return_corrupt!("Free block extends beyond page end"); } else { // Requested amount fits inside the current free slot so we reduce its size // to account for newly allocated space. - page_ref.write_u16(pc + 2, new_size as u16); + page_ref.write_u16_no_offset(pc + 2, new_size as u16); return Ok(pc + new_size); } } @@ -4123,18 +3951,18 @@ fn find_free_cell(page_ref: &PageContent, usable_space: u16, amount: usize) -> R pub fn btree_init_page(page: &PageRef, page_type: PageType, offset: usize, usable_space: u16) { // setup btree page let contents = page.get(); - debug!("btree_init_page(id={}, offset={})", contents.id, offset); + tracing::debug!("btree_init_page(id={}, offset={})", contents.id, offset); let contents = contents.contents.as_mut().unwrap(); contents.offset = offset; let id = page_type as u8; - contents.write_u8(PAGE_HEADER_OFFSET_PAGE_TYPE, id); - contents.write_u16(PAGE_HEADER_OFFSET_FIRST_FREEBLOCK, 0); - contents.write_u16(PAGE_HEADER_OFFSET_CELL_COUNT, 0); + contents.write_u8(offset::BTREE_PAGE_TYPE, id); + contents.write_u16(offset::BTREE_FIRST_FREEBLOCK, 0); + contents.write_u16(offset::BTREE_CELL_COUNT, 0); - contents.write_u16(PAGE_HEADER_OFFSET_CELL_CONTENT_AREA, usable_space); + contents.write_u16(offset::BTREE_CELL_CONTENT_AREA, usable_space); - contents.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, 0); - contents.write_u32(PAGE_HEADER_OFFSET_RIGHTMOST_PTR, 0); + contents.write_u8(offset::BTREE_FRAGMENTED_BYTES_COUNT, 0); + contents.write_u32(offset::BTREE_RIGHTMOST_PTR, 0); } fn to_static_buf(buf: &mut [u8]) -> &'static mut [u8] { @@ -4231,7 +4059,7 @@ fn edit_page( )?; debug_validate_cells!(page, usable_space); // TODO: noverflow - page.write_u16(PAGE_HEADER_OFFSET_CELL_COUNT, number_new_cells as u16); + page.write_u16(offset::BTREE_CELL_COUNT, number_new_cells as u16); Ok(()) } @@ -4261,7 +4089,7 @@ fn page_free_array( let offset = (cell_pointer.start as usize - buf_range.start as usize) as u16; let len = (cell_pointer.end as usize - cell_pointer.start as usize) as u16; free_cell_range(page, offset, len, usable_space)?; - page.write_u16(PAGE_HEADER_OFFSET_CELL_COUNT, page.cell_count() as u16 - 1); + page.write_u16(offset::BTREE_CELL_COUNT, page.cell_count() as u16 - 1); number_of_cells_removed += 1; } } @@ -4368,10 +4196,14 @@ fn free_cell_range( } } if removed_fragmentation > page.num_frag_free_bytes() { - return_corrupt!("Invalid fragmentation count"); + return_corrupt!(format!( + "Invalid fragmentation count. Had {} and removed {}", + page.num_frag_free_bytes(), + removed_fragmentation + )); } let frag = page.num_frag_free_bytes() - removed_fragmentation; - page.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, frag); + page.write_u8(offset::BTREE_FRAGMENTED_BYTES_COUNT, frag); pc }; @@ -4379,11 +4211,11 @@ fn free_cell_range( if offset < page.cell_content_area() { return_corrupt!("Free block before content area"); } - if pointer_to_pc != page.offset as u16 + PAGE_HEADER_OFFSET_FIRST_FREEBLOCK as u16 { + if pointer_to_pc != page.offset as u16 + offset::BTREE_FIRST_FREEBLOCK as u16 { return_corrupt!("Invalid content area merge"); } - page.write_u16(PAGE_HEADER_OFFSET_FIRST_FREEBLOCK, pc); - page.write_u16(PAGE_HEADER_OFFSET_CELL_CONTENT_AREA, end); + page.write_u16(offset::BTREE_FIRST_FREEBLOCK, pc); + page.write_u16(offset::BTREE_CELL_CONTENT_AREA, end); } else { page.write_u16_no_offset(pointer_to_pc as usize, offset); page.write_u16_no_offset(offset as usize, pc); @@ -4448,10 +4280,10 @@ fn defragment_page(page: &PageContent, usable_space: u16) { assert!(cbrk >= first_cell); // set new first byte of cell content - page.write_u16(PAGE_HEADER_OFFSET_CELL_CONTENT_AREA, cbrk); + page.write_u16(offset::BTREE_CELL_CONTENT_AREA, cbrk); // set free block to 0, unused spaced can be retrieved from gap between cell pointer end and content start - page.write_u16(PAGE_HEADER_OFFSET_FIRST_FREEBLOCK, 0); - page.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, 0); + page.write_u16(offset::BTREE_FIRST_FREEBLOCK, 0); + page.write_u8(offset::BTREE_FRAGMENTED_BYTES_COUNT, 0); debug_validate_cells!(page, usable_space); } @@ -4544,7 +4376,7 @@ fn insert_into_cell( // update cell count let new_n_cells = (page.cell_count() + 1) as u16; - page.write_u16(PAGE_HEADER_OFFSET_CELL_COUNT, new_n_cells); + page.write_u16(offset::BTREE_CELL_COUNT, new_n_cells); debug_validate_cells!(page, usable_space); Ok(()) } @@ -4656,12 +4488,12 @@ fn allocate_cell_space(page_ref: &PageContent, amount: u16, usable_space: u16) - if gap + 2 + amount > top { // defragment defragment_page(page_ref, usable_space); - top = page_ref.read_u16(PAGE_HEADER_OFFSET_CELL_CONTENT_AREA) as usize; + top = page_ref.read_u16(offset::BTREE_CELL_CONTENT_AREA) as usize; } top -= amount; - page_ref.write_u16(PAGE_HEADER_OFFSET_CELL_CONTENT_AREA, top as u16); + page_ref.write_u16(offset::BTREE_CELL_CONTENT_AREA, top as u16); assert!(top + amount <= usable_space as usize); Ok(top as u16) @@ -4694,7 +4526,7 @@ fn fill_cell_payload( } let payload_overflow_threshold_max = payload_overflow_threshold_max(page_type, usable_space); - debug!( + tracing::debug!( "fill_cell_payload(record_size={}, payload_overflow_threshold_max={})", record_buf.len(), payload_overflow_threshold_max @@ -4820,11 +4652,11 @@ fn drop_cell(page: &mut PageContent, cell_idx: usize, usable_space: u16) -> Resu if page.cell_count() > 1 { shift_pointers_left(page, cell_idx); } else { - page.write_u16(PAGE_HEADER_OFFSET_CELL_CONTENT_AREA, usable_space); - page.write_u16(PAGE_HEADER_OFFSET_FIRST_FREEBLOCK, 0); - page.write_u8(PAGE_HEADER_OFFSET_FRAGMENTED_BYTES_COUNT, 0); + page.write_u16(offset::BTREE_CELL_CONTENT_AREA, usable_space); + page.write_u16(offset::BTREE_FIRST_FREEBLOCK, 0); + page.write_u8(offset::BTREE_FRAGMENTED_BYTES_COUNT, 0); } - page.write_u16(PAGE_HEADER_OFFSET_CELL_COUNT, page.cell_count() as u16 - 1); + page.write_u16(offset::BTREE_CELL_COUNT, page.cell_count() as u16 - 1); debug_validate_cells!(page, usable_space); Ok(()) } @@ -4844,31 +4676,28 @@ fn shift_pointers_left(page: &mut PageContent, cell_idx: usize) { #[cfg(test)] mod tests { - use rand::thread_rng; - use rand::Rng; - use rand_chacha::rand_core::RngCore; - use rand_chacha::rand_core::SeedableRng; - use rand_chacha::ChaCha8Rng; + use rand::{thread_rng, Rng}; + use rand_chacha::{ + rand_core::{RngCore, SeedableRng}, + ChaCha8Rng, + }; use test_log::test; use super::*; - use crate::fast_lock::SpinLock; - use crate::io::{Buffer, Completion, MemoryIO, OpenFlags, IO}; - use crate::storage::database::DatabaseFile; - use crate::storage::page_cache::DumbLruPageCache; - use crate::storage::sqlite3_ondisk; - use crate::storage::sqlite3_ondisk::DatabaseHeader; - use crate::types::Text; - use crate::vdbe::Register; - use crate::Connection; - use crate::{BufferPool, DatabaseStorage, WalFile, WalFileShared, WriteCompletion}; - use std::cell::RefCell; - use std::collections::HashSet; - use std::mem::transmute; - use std::ops::Deref; - use std::panic; - use std::rc::Rc; - use std::sync::Arc; + use crate::{ + fast_lock::SpinLock, + io::{Buffer, Completion, MemoryIO, OpenFlags, IO}, + storage::{ + database::DatabaseFile, page_cache::DumbLruPageCache, sqlite3_ondisk, + sqlite3_ondisk::DatabaseHeader, + }, + types::Text, + vdbe::Register, + BufferPool, Connection, DatabaseStorage, WalFile, WalFileShared, WriteCompletion, + }; + use std::{ + cell::RefCell, collections::HashSet, mem::transmute, ops::Deref, panic, rc::Rc, sync::Arc, + }; use tempfile::TempDir; @@ -4893,14 +4722,13 @@ mod tests { let page = Arc::new(Page::new(id)); let drop_fn = Rc::new(|_| {}); - let inner = PageContent { - offset: 0, - buffer: Arc::new(RefCell::new(Buffer::new( + let inner = PageContent::new( + 0, + Arc::new(RefCell::new(Buffer::new( BufferData::new(vec![0; 4096]), drop_fn, ))), - overflow_cells: Vec::new(), - }; + ); page.get().contents.replace(inner); btree_init_page(&page, PageType::TableLeaf, 0, 4096); @@ -5338,8 +5166,6 @@ mod tests { // FIXME: add sorted vector instead, should be okay for small amounts of keys for now :P, too lazy to fix right now keys.sort(); cursor.move_to_root(); - // hack to allow bypassing our internal invariant of not allowing cursor iteration after SeekOp::EQ - cursor.iteration_state = IterationState::Iterating(IterationDirection::Forwards); let mut valid = true; for key in keys.iter() { tracing::trace!("seeking key: {}", key); @@ -5351,7 +5177,6 @@ mod tests { break; } } - cursor.iteration_state = IterationState::Unset; // let's validate btree too so that we undertsand where the btree failed if matches!(validate_btree(pager.clone(), root_page), (_, false)) || !valid { let btree_after = format_btree(pager.clone(), root_page, 0); @@ -5369,8 +5194,6 @@ mod tests { } keys.sort(); cursor.move_to_root(); - // hack to allow bypassing our internal invariant of not allowing cursor iteration after SeekOp::EQ - cursor.iteration_state = IterationState::Iterating(IterationDirection::Forwards); for key in keys.iter() { tracing::trace!("seeking key: {}", key); run_until_done(|| cursor.next(), pager.deref()).unwrap(); @@ -5686,7 +5509,7 @@ mod tests { let contents = root_page.get().contents.as_mut().unwrap(); // Set rightmost pointer to page4 - contents.write_u32(PAGE_HEADER_OFFSET_RIGHTMOST_PTR, page4.get().id as u32); + contents.write_u32(offset::BTREE_RIGHTMOST_PTR, page4.get().id as u32); // Create a cell with pointer to page3 let cell_content = vec![ @@ -6242,7 +6065,7 @@ mod tests { run_until_done( || { let key = SeekKey::TableRowId(i as u64); - cursor.seek(key, SeekOp::EQ) + cursor.move_to(key, SeekOp::EQ) }, pager.deref(), ) @@ -6322,7 +6145,7 @@ mod tests { run_until_done( || { let key = SeekKey::TableRowId(i as u64); - cursor.seek(key, SeekOp::EQ) + cursor.move_to(key, SeekOp::EQ) }, pager.deref(), ) @@ -6404,7 +6227,7 @@ mod tests { run_until_done( || { let key = SeekKey::TableRowId(i as u64); - cursor.seek(key, SeekOp::EQ) + cursor.move_to(key, SeekOp::EQ) }, pager.deref(), ) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 70af1c8d2..9d6d90c00 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -637,11 +637,7 @@ pub fn allocate_page(page_id: usize, buffer_pool: &Rc, offset: usize }); let buffer = Arc::new(RefCell::new(Buffer::new(buffer, drop_fn))); page.set_loaded(); - page.get().contents = Some(PageContent { - offset, - buffer, - overflow_cells: Vec::new(), - }); + page.get().contents = Some(PageContent::new(offset, buffer)); } page } diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index b8373514f..10251ca51 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -413,6 +413,14 @@ impl Clone for PageContent { } impl PageContent { + pub fn new(offset: usize, buffer: Arc>) -> Self { + Self { + offset, + buffer, + overflow_cells: Vec::new(), + } + } + pub fn page_type(&self) -> PageType { self.read_u8(0).try_into().unwrap() } @@ -741,11 +749,7 @@ fn finish_read_page( } else { 0 }; - let inner = PageContent { - offset: pos, - buffer: buffer_ref.clone(), - overflow_cells: Vec::new(), - }; + let inner = PageContent::new(pos, buffer_ref.clone()); { page.get().contents.replace(inner); page.set_uptodate(); diff --git a/core/storage/wal.rs b/core/storage/wal.rs index b56246a78..2d1f17776 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -698,11 +698,10 @@ impl WalFile { let drop_fn = Rc::new(move |buf| { buffer_pool.put(buf); }); - checkpoint_page.get().contents = Some(PageContent { - offset: 0, - buffer: Arc::new(RefCell::new(Buffer::new(buffer, drop_fn))), - overflow_cells: Vec::new(), - }); + checkpoint_page.get().contents = Some(PageContent::new( + 0, + Arc::new(RefCell::new(Buffer::new(buffer, drop_fn))), + )); } Self { io, diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 5049bb738..21e311bba 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -397,10 +397,12 @@ fn emit_delete_insns( let cursor_id = match &table_reference.op { Operation::Scan { .. } => program.resolve_cursor_id(&table_reference.identifier), Operation::Search(search) => match search { - Search::RowidEq { .. } | Search::RowidSearch { .. } => { + Search::RowidEq { .. } | Search::Seek { index: None, .. } => { program.resolve_cursor_id(&table_reference.identifier) } - Search::IndexSearch { index, .. } => program.resolve_cursor_id(&index.name), + Search::Seek { + index: Some(index), .. + } => program.resolve_cursor_id(&index.name), }, _ => return Ok(()), }; @@ -537,12 +539,14 @@ fn emit_update_insns( table_ref.virtual_table().is_some(), ), Operation::Search(search) => match search { - &Search::RowidEq { .. } | Search::RowidSearch { .. } => ( + &Search::RowidEq { .. } | Search::Seek { index: None, .. } => ( program.resolve_cursor_id(&table_ref.identifier), None, false, ), - Search::IndexSearch { index, .. } => ( + Search::Seek { + index: Some(index), .. + } => ( program.resolve_cursor_id(&table_ref.identifier), Some((index.clone(), program.resolve_cursor_id(&index.name))), false, diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 7bb0dc228..958005259 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1309,6 +1309,33 @@ pub fn translate_expr( }); Ok(target_register) } + ScalarFunc::TimeDiff => { + let args = expect_arguments_exact!(args, 2, srf); + + let start_reg = program.alloc_registers(2); + translate_expr( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; + translate_expr( + program, + referenced_tables, + &args[1], + start_reg + 1, + resolver, + )?; + + program.emit_insn(Insn::Function { + constant_mask: 0, + start_reg, + dest: target_register, + func: func_ctx, + }); + Ok(target_register) + } ScalarFunc::TotalChanges => { if args.is_some() { crate::bail_parse_error!( @@ -1598,6 +1625,58 @@ pub fn translate_expr( }); Ok(target_register) } + ScalarFunc::Likelihood => { + let args = if let Some(args) = args { + if args.len() != 2 { + crate::bail_parse_error!( + "likelihood() function must have exactly 2 arguments", + ); + } + args + } else { + crate::bail_parse_error!("likelihood() function with no arguments",); + }; + + if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = args[1] { + if let Ok(probability) = value.parse::() { + if !(0.0..=1.0).contains(&probability) { + crate::bail_parse_error!( + "second argument of likelihood() must be between 0.0 and 1.0", + ); + } + if !value.contains('.') { + crate::bail_parse_error!( + "second argument of likelihood() must be a floating point number with decimal point", + ); + } + } else { + crate::bail_parse_error!( + "second argument of likelihood() must be a floating point constant", + ); + } + } else { + crate::bail_parse_error!( + "second argument of likelihood() must be a numeric literal", + ); + } + + let start_reg = program.alloc_register(); + translate_and_mark( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; + + program.emit_insn(Insn::Copy { + src_reg: start_reg, + dst_reg: target_register, + amount: 0, + }); + + Ok(target_register) + } } } Func::Math(math_func) => match math_func.arity() { diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 51bd05382..8409e31c9 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -1,8 +1,7 @@ -use limbo_sqlite3_parser::ast; - use crate::{ schema::Table, translate::result_row::emit_select_result, + types::SeekOp, vdbe::{ builder::{CursorType, ProgramBuilder}, insn::{CmpInsFlags, Insn}, @@ -18,8 +17,8 @@ use super::{ group_by::is_column_in_group_by, order_by::{order_by_sorter_insert, sorter_insert}, plan::{ - IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference, - WhereTerm, + IterationDirection, Operation, Search, SeekDef, SelectPlan, SelectQueryType, + TableReference, WhereTerm, }, }; @@ -166,7 +165,10 @@ pub fn init_loop( } } - if let Search::IndexSearch { index, .. } = search { + if let Search::Seek { + index: Some(index), .. + } = search + { let index_cursor_id = program.alloc_cursor_id( Some(index.name.clone()), CursorType::BTreeIndex(index.clone()), @@ -381,268 +383,42 @@ pub fn open_loop( }); } else { // Otherwise, it's an index/rowid scan, i.e. first a seek is performed and then a scan until the comparison expression is not satisfied anymore. - let index_cursor_id = if let Search::IndexSearch { index, .. } = search { + let index_cursor_id = if let Search::Seek { + index: Some(index), .. + } = search + { Some(program.resolve_cursor_id(&index.name)) } else { None }; - let (cmp_expr, cmp_op, iter_dir) = match search { - Search::IndexSearch { - cmp_expr, - cmp_op, - iter_dir, - .. - } => (cmp_expr, cmp_op, iter_dir), - Search::RowidSearch { - cmp_expr, - cmp_op, - iter_dir, - } => (cmp_expr, cmp_op, iter_dir), - Search::RowidEq { .. } => unreachable!(), + let is_index = index_cursor_id.is_some(); + let seek_cursor_id = index_cursor_id.unwrap_or(table_cursor_id); + let Search::Seek { seek_def, .. } = search else { + unreachable!("Rowid equality point lookup should have been handled above"); }; - // There are a few steps in an index seek: - // 1. Emit the comparison expression for the rowid/index seek. For example, if we a clause 'WHERE index_key >= 10', we emit the comparison expression 10 into cmp_reg. - // - // 2. Emit the seek instruction. SeekGE and SeekGT are used in forwards iteration, SeekLT and SeekLE are used in backwards iteration. - // All of the examples below assume an ascending index, because we do not support descending indexes yet. - // If we are scanning the ascending index: - // - Forwards, and have a GT/GE/EQ comparison, the comparison expression from step 1 is used as the value to seek to, because that is the lowest possible value that satisfies the clause. - // - Forwards, and have a LT/LE comparison, NULL is used as the comparison expression because we actually want to start scanning from the beginning of the index. - // - Backwards, and have a GT/GE comparison, no Seek instruction is emitted and we emit LastAsync instead, because we want to start scanning from the end of the index. - // - Backwards, and have a LT/LE/EQ comparison, we emit a Seek instruction with the comparison expression from step 1 as the value to seek to, since that is the highest possible - // value that satisfies the clause. - let seek_cmp_reg = program.alloc_register(); - let mut comparison_expr_translated = false; - match (cmp_op, iter_dir) { - // Forwards, GT/GE/EQ -> use the comparison expression (i.e. seek to the first key where the cmp expr is satisfied, and then scan forwards) - ( - ast::Operator::Equals - | ast::Operator::Greater - | ast::Operator::GreaterEquals, - IterationDirection::Forwards, - ) => { - translate_expr( - program, - Some(tables), - &cmp_expr.expr, - seek_cmp_reg, - &t_ctx.resolver, - )?; - comparison_expr_translated = true; - match cmp_op { - ast::Operator::Equals | ast::Operator::GreaterEquals => { - program.emit_insn(Insn::SeekGE { - is_index: index_cursor_id.is_some(), - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - start_reg: seek_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } - ast::Operator::Greater => { - program.emit_insn(Insn::SeekGT { - is_index: index_cursor_id.is_some(), - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - start_reg: seek_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } - _ => unreachable!(), - } - } - // Forwards, LT/LE -> use NULL (i.e. start from the beginning of the index) - ( - ast::Operator::Less | ast::Operator::LessEquals, - IterationDirection::Forwards, - ) => { - program.emit_insn(Insn::Null { - dest: seek_cmp_reg, - dest_end: None, - }); - program.emit_insn(Insn::SeekGT { - is_index: index_cursor_id.is_some(), - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - start_reg: seek_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } - // Backwards, GT/GE -> no seek, emit LastAsync (i.e. start from the end of the index) - ( - ast::Operator::Greater | ast::Operator::GreaterEquals, - IterationDirection::Backwards, - ) => { - program.emit_insn(Insn::LastAsync { - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - }); - program.emit_insn(Insn::LastAwait { - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - pc_if_empty: loop_end, - }); - } - // Backwards, LT/LE/EQ -> use the comparison expression (i.e. seek from the end of the index until the cmp expr is satisfied, and then scan backwards) - ( - ast::Operator::Less | ast::Operator::LessEquals | ast::Operator::Equals, - IterationDirection::Backwards, - ) => { - translate_expr( - program, - Some(tables), - &cmp_expr.expr, - seek_cmp_reg, - &t_ctx.resolver, - )?; - comparison_expr_translated = true; - match cmp_op { - ast::Operator::Less => { - program.emit_insn(Insn::SeekLT { - is_index: index_cursor_id.is_some(), - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - start_reg: seek_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } - ast::Operator::LessEquals | ast::Operator::Equals => { - program.emit_insn(Insn::SeekLE { - is_index: index_cursor_id.is_some(), - cursor_id: index_cursor_id.unwrap_or(table_cursor_id), - start_reg: seek_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } - _ => unreachable!(), - } - } - _ => unreachable!(), - }; - - program.resolve_label(loop_start, program.offset()); - - let scan_terminating_cmp_reg = if comparison_expr_translated { - seek_cmp_reg - } else { - let reg = program.alloc_register(); - translate_expr( - program, - Some(tables), - &cmp_expr.expr, - reg, - &t_ctx.resolver, - )?; - reg - }; - - // 3. Emit a scan-terminating comparison instruction (IdxGT, IdxGE, IdxLT, IdxLE if index; GT, GE, LT, LE if btree rowid scan). - // Here the comparison expression from step 1 is compared to the current index key and the loop is exited if the comparison is true. - // The comparison operator used in the Idx__ instruction is the inverse of the WHERE clause comparison operator. - // For example, if we are scanning forwards and have a clause 'WHERE index_key < 10', we emit IdxGE(10) since >=10 is the first key where our condition is not satisfied anymore. - match (cmp_op, iter_dir) { - // Forwards, <= -> terminate if > - ( - ast::Operator::Equals | ast::Operator::LessEquals, - IterationDirection::Forwards, - ) => { - if let Some(index_cursor_id) = index_cursor_id { - program.emit_insn(Insn::IdxGT { - cursor_id: index_cursor_id, - start_reg: scan_terminating_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } else { - let rowid_reg = program.alloc_register(); - program.emit_insn(Insn::RowId { - cursor_id: table_cursor_id, - dest: rowid_reg, - }); - program.emit_insn(Insn::Gt { - lhs: rowid_reg, - rhs: scan_terminating_cmp_reg, - target_pc: loop_end, - flags: CmpInsFlags::default(), - }); - } - } - // Forwards, < -> terminate if >= - (ast::Operator::Less, IterationDirection::Forwards) => { - if let Some(index_cursor_id) = index_cursor_id { - program.emit_insn(Insn::IdxGE { - cursor_id: index_cursor_id, - start_reg: scan_terminating_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } else { - let rowid_reg = program.alloc_register(); - program.emit_insn(Insn::RowId { - cursor_id: table_cursor_id, - dest: rowid_reg, - }); - program.emit_insn(Insn::Ge { - lhs: rowid_reg, - rhs: scan_terminating_cmp_reg, - target_pc: loop_end, - flags: CmpInsFlags::default(), - }); - } - } - // Backwards, >= -> terminate if < - ( - ast::Operator::Equals | ast::Operator::GreaterEquals, - IterationDirection::Backwards, - ) => { - if let Some(index_cursor_id) = index_cursor_id { - program.emit_insn(Insn::IdxLT { - cursor_id: index_cursor_id, - start_reg: scan_terminating_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } else { - let rowid_reg = program.alloc_register(); - program.emit_insn(Insn::RowId { - cursor_id: table_cursor_id, - dest: rowid_reg, - }); - program.emit_insn(Insn::Lt { - lhs: rowid_reg, - rhs: scan_terminating_cmp_reg, - target_pc: loop_end, - flags: CmpInsFlags::default(), - }); - } - } - // Backwards, > -> terminate if <= - (ast::Operator::Greater, IterationDirection::Backwards) => { - if let Some(index_cursor_id) = index_cursor_id { - program.emit_insn(Insn::IdxLE { - cursor_id: index_cursor_id, - start_reg: scan_terminating_cmp_reg, - num_regs: 1, - target_pc: loop_end, - }); - } else { - let rowid_reg = program.alloc_register(); - program.emit_insn(Insn::RowId { - cursor_id: table_cursor_id, - dest: rowid_reg, - }); - program.emit_insn(Insn::Le { - lhs: rowid_reg, - rhs: scan_terminating_cmp_reg, - target_pc: loop_end, - flags: CmpInsFlags::default(), - }); - } - } - // Forwards, > and >= -> we already did a seek to the first key where the cmp expr is satisfied, so we dont have a terminating condition - // Backwards, < and <= -> we already did a seek to the last key where the cmp expr is satisfied, so we dont have a terminating condition - _ => {} - } + let start_reg = program.alloc_registers(seek_def.key.len()); + emit_seek( + program, + tables, + seek_def, + t_ctx, + seek_cursor_id, + start_reg, + loop_end, + is_index, + )?; + emit_seek_termination( + program, + tables, + seek_def, + t_ctx, + seek_cursor_id, + start_reg, + loop_start, + loop_end, + is_index, + )?; if let Some(index_cursor_id) = index_cursor_id { // Don't do a btree table seek until it's actually necessary to read from the table. @@ -1002,12 +778,19 @@ pub fn close_loop( // Rowid equality point lookups are handled with a SeekRowid instruction which does not loop, so there is no need to emit a NextAsync instruction. if !matches!(search, Search::RowidEq { .. }) { let (cursor_id, iter_dir) = match search { - Search::IndexSearch { - index, iter_dir, .. - } => (program.resolve_cursor_id(&index.name), *iter_dir), - Search::RowidSearch { iter_dir, .. } => { - (program.resolve_cursor_id(&table.identifier), *iter_dir) - } + Search::Seek { + index: Some(index), + seek_def, + .. + } => (program.resolve_cursor_id(&index.name), seek_def.iter_dir), + Search::Seek { + index: None, + seek_def, + .. + } => ( + program.resolve_cursor_id(&table.identifier), + seek_def.iter_dir, + ), Search::RowidEq { .. } => unreachable!(), }; @@ -1074,3 +857,201 @@ pub fn close_loop( } Ok(()) } + +/// Emits instructions for an index seek. See e.g. [crate::translate::plan::SeekDef] +/// for more details about the seek definition. +/// +/// Index seeks always position the cursor to the first row that matches the seek key, +/// and then continue to emit rows until the termination condition is reached, +/// see [emit_seek_termination] below. +/// +/// If either 1. the seek finds no rows or 2. the termination condition is reached, +/// the loop for that given table/index is fully exited. +#[allow(clippy::too_many_arguments)] +fn emit_seek( + program: &mut ProgramBuilder, + tables: &[TableReference], + seek_def: &SeekDef, + t_ctx: &mut TranslateCtx, + seek_cursor_id: usize, + start_reg: usize, + loop_end: BranchOffset, + is_index: bool, +) -> Result<()> { + let Some(seek) = seek_def.seek.as_ref() else { + assert!(seek_def.iter_dir == IterationDirection::Backwards, "A SeekDef without a seek operation should only be used in backwards iteration direction"); + program.emit_insn(Insn::LastAsync { + cursor_id: seek_cursor_id, + }); + program.emit_insn(Insn::LastAwait { + cursor_id: seek_cursor_id, + pc_if_empty: loop_end, + }); + return Ok(()); + }; + // We allocated registers for the full index key, but our seek key might not use the full index key. + // Later on for the termination condition we will overwrite the NULL registers. + // See [crate::translate::optimizer::build_seek_def] for more details about in which cases we do and don't use the full index key. + for i in 0..seek_def.key.len() { + let reg = start_reg + i; + if i >= seek.len { + if seek_def.null_pad_unset_cols() { + program.emit_insn(Insn::Null { + dest: reg, + dest_end: None, + }); + } + } else { + translate_expr( + program, + Some(tables), + &seek_def.key[i], + reg, + &t_ctx.resolver, + )?; + } + } + let num_regs = if seek_def.null_pad_unset_cols() { + seek_def.key.len() + } else { + seek.len + }; + match seek.op { + SeekOp::GE => program.emit_insn(Insn::SeekGE { + is_index, + cursor_id: seek_cursor_id, + start_reg, + num_regs, + target_pc: loop_end, + }), + SeekOp::GT => program.emit_insn(Insn::SeekGT { + is_index, + cursor_id: seek_cursor_id, + start_reg, + num_regs, + target_pc: loop_end, + }), + SeekOp::LE => program.emit_insn(Insn::SeekLE { + is_index, + cursor_id: seek_cursor_id, + start_reg, + num_regs, + target_pc: loop_end, + }), + SeekOp::LT => program.emit_insn(Insn::SeekLT { + is_index, + cursor_id: seek_cursor_id, + start_reg, + num_regs, + target_pc: loop_end, + }), + SeekOp::EQ => panic!("An index seek is never EQ"), + }; + + Ok(()) +} + +/// Emits instructions for an index seek termination. See e.g. [crate::translate::plan::SeekDef] +/// for more details about the seek definition. +/// +/// Index seeks always position the cursor to the first row that matches the seek key +/// (see [emit_seek] above), and then continue to emit rows until the termination condition +/// (if any) is reached. +/// +/// If the termination condition is not present, the cursor is fully scanned to the end. +#[allow(clippy::too_many_arguments)] +fn emit_seek_termination( + program: &mut ProgramBuilder, + tables: &[TableReference], + seek_def: &SeekDef, + t_ctx: &mut TranslateCtx, + seek_cursor_id: usize, + start_reg: usize, + loop_start: BranchOffset, + loop_end: BranchOffset, + is_index: bool, +) -> Result<()> { + let Some(termination) = seek_def.termination.as_ref() else { + program.resolve_label(loop_start, program.offset()); + return Ok(()); + }; + let num_regs = termination.len; + // If the seek termination was preceded by a seek (which happens in most cases), + // we can re-use the registers that were allocated for the full index key. + let start_idx = seek_def.seek.as_ref().map_or(0, |seek| seek.len); + for i in start_idx..termination.len { + let reg = start_reg + i; + translate_expr( + program, + Some(tables), + &seek_def.key[i], + reg, + &t_ctx.resolver, + )?; + } + program.resolve_label(loop_start, program.offset()); + let mut rowid_reg = None; + if !is_index { + rowid_reg = Some(program.alloc_register()); + program.emit_insn(Insn::RowId { + cursor_id: seek_cursor_id, + dest: rowid_reg.unwrap(), + }); + } + + match (is_index, termination.op) { + (true, SeekOp::GE) => program.emit_insn(Insn::IdxGE { + cursor_id: seek_cursor_id, + start_reg, + num_regs, + target_pc: loop_end, + }), + (true, SeekOp::GT) => program.emit_insn(Insn::IdxGT { + cursor_id: seek_cursor_id, + start_reg, + num_regs, + target_pc: loop_end, + }), + (true, SeekOp::LE) => program.emit_insn(Insn::IdxLE { + cursor_id: seek_cursor_id, + start_reg, + num_regs, + target_pc: loop_end, + }), + (true, SeekOp::LT) => program.emit_insn(Insn::IdxLT { + cursor_id: seek_cursor_id, + start_reg, + num_regs, + target_pc: loop_end, + }), + (false, SeekOp::GE) => program.emit_insn(Insn::Ge { + lhs: rowid_reg.unwrap(), + rhs: start_reg, + target_pc: loop_end, + flags: CmpInsFlags::default(), + }), + (false, SeekOp::GT) => program.emit_insn(Insn::Gt { + lhs: rowid_reg.unwrap(), + rhs: start_reg, + target_pc: loop_end, + flags: CmpInsFlags::default(), + }), + (false, SeekOp::LE) => program.emit_insn(Insn::Le { + lhs: rowid_reg.unwrap(), + rhs: start_reg, + target_pc: loop_end, + flags: CmpInsFlags::default(), + }), + (false, SeekOp::LT) => program.emit_insn(Insn::Lt { + lhs: rowid_reg.unwrap(), + rhs: start_reg, + target_pc: loop_end, + flags: CmpInsFlags::default(), + }), + (_, SeekOp::EQ) => { + panic!("An index termination condition is never EQ") + } + }; + + Ok(()) +} diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 772ed81e7..609acd906 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -4,13 +4,15 @@ use limbo_sqlite3_parser::ast::{self, Expr, SortOrder}; use crate::{ schema::{Index, Schema}, + translate::plan::TerminationKey, + types::SeekOp, util::exprs_are_equivalent, Result, }; use super::plan::{ - DeletePlan, Direction, GroupBy, IterationDirection, Operation, Plan, Search, SelectPlan, - TableReference, UpdatePlan, WhereTerm, + DeletePlan, Direction, GroupBy, IterationDirection, Operation, Plan, Search, SeekDef, SeekKey, + SelectPlan, TableReference, UpdatePlan, WhereTerm, }; pub fn optimize_plan(plan: &mut Plan, schema: &Schema) -> Result<()> { @@ -296,24 +298,62 @@ fn use_indexes( ) -> Result<()> { // Try to use indexes for eliminating ORDER BY clauses eliminate_unnecessary_orderby(table_references, available_indexes, order_by, group_by)?; + // Try to use indexes for WHERE conditions - 'outer: for (table_index, table_reference) in table_references.iter_mut().enumerate() { - if let Operation::Scan { iter_dir, .. } = &table_reference.op { - let mut i = 0; - while i < where_clause.len() { - let cond = where_clause.get_mut(i).unwrap(); - if let Some(index_search) = try_extract_index_search_expression( - cond, - table_index, - table_reference, - available_indexes, - *iter_dir, - )? { - where_clause.remove(i); - table_reference.op = Operation::Search(index_search); - continue 'outer; + for (table_index, table_reference) in table_references.iter_mut().enumerate() { + if matches!(table_reference.op, Operation::Scan { .. }) { + let index = if let Operation::Scan { index, .. } = &table_reference.op { + Option::clone(index) + } else { + None + }; + match index { + // If we decided to eliminate ORDER BY using an index, let's constrain our search to only that index + Some(index) => { + let available_indexes = available_indexes + .values() + .flatten() + .filter(|i| i.name == index.name) + .cloned() + .collect::>(); + if let Some(search) = try_extract_index_search_from_where_clause( + where_clause, + table_index, + table_reference, + &available_indexes, + )? { + table_reference.op = Operation::Search(search); + } + } + None => { + let table_name = table_reference.table.get_name(); + + // If we can utilize the rowid alias of the table, let's preferentially always use it for now. + let mut i = 0; + while i < where_clause.len() { + if let Some(search) = try_extract_rowid_search_expression( + &mut where_clause[i], + table_index, + table_reference, + )? { + where_clause.remove(i); + table_reference.op = Operation::Search(search); + continue; + } else { + i += 1; + } + } + if let Some(indexes) = available_indexes.get(table_name) { + if let Some(search) = try_extract_index_search_from_where_clause( + where_clause, + table_index, + table_reference, + indexes, + )? { + table_reference.op = Operation::Search(search); + } + } } - i += 1; } } } @@ -431,12 +471,6 @@ pub trait Optimizable { .map_or(false, |c| c == ConstantPredicate::AlwaysFalse)) } fn is_rowid_alias_of(&self, table_index: usize) -> bool; - fn check_index_scan( - &mut self, - table_index: usize, - table_reference: &TableReference, - available_indexes: &HashMap>>, - ) -> Result>>; } impl Optimizable for ast::Expr { @@ -450,79 +484,6 @@ impl Optimizable for ast::Expr { _ => false, } } - fn check_index_scan( - &mut self, - table_index: usize, - table_reference: &TableReference, - available_indexes: &HashMap>>, - ) -> Result>> { - match self { - Self::Column { table, column, .. } => { - if *table != table_index { - return Ok(None); - } - let Some(available_indexes_for_table) = - available_indexes.get(table_reference.table.get_name()) - else { - return Ok(None); - }; - let Some(column) = table_reference.table.get_column_at(*column) else { - return Ok(None); - }; - for index in available_indexes_for_table.iter() { - if let Some(name) = column.name.as_ref() { - if &index.columns.first().unwrap().name == name { - return Ok(Some(index.clone())); - } - } - } - Ok(None) - } - Self::Binary(lhs, op, rhs) => { - // Only consider index scans for binary ops that are comparisons. - // e.g. "t1.id = t2.id" is a valid index scan, but "t1.id + 1" is not. - // - // TODO/optimization: consider detecting index scan on e.g. table t1 in - // "WHERE t1.id + 1 = t2.id" - // here the Expr could be rewritten to "t1.id = t2.id - 1" - // and then t1.id could be used as an index key. - if !matches!( - *op, - ast::Operator::Equals - | ast::Operator::Greater - | ast::Operator::GreaterEquals - | ast::Operator::Less - | ast::Operator::LessEquals - ) { - return Ok(None); - } - let lhs_index = - lhs.check_index_scan(table_index, &table_reference, available_indexes)?; - if lhs_index.is_some() { - return Ok(lhs_index); - } - let rhs_index = - rhs.check_index_scan(table_index, &table_reference, available_indexes)?; - if rhs_index.is_some() { - // swap lhs and rhs - let swapped_operator = match *op { - ast::Operator::Equals => ast::Operator::Equals, - ast::Operator::Greater => ast::Operator::Less, - ast::Operator::GreaterEquals => ast::Operator::LessEquals, - ast::Operator::Less => ast::Operator::Greater, - ast::Operator::LessEquals => ast::Operator::GreaterEquals, - _ => unreachable!(), - }; - let lhs_new = rhs.take_ownership(); - let rhs_new = lhs.take_ownership(); - *self = Self::Binary(Box::new(lhs_new), swapped_operator, Box::new(rhs_new)); - return Ok(rhs_index); - } - Ok(None) - } - _ => Ok(None), - } - } fn check_constant(&self) -> Result> { match self { Self::Literal(lit) => match lit { @@ -652,13 +613,506 @@ fn opposite_cmp_op(op: ast::Operator) -> ast::Operator { } } -pub fn try_extract_index_search_expression( +/// Struct used for scoring index scans +/// Currently we just score by the number of index columns that can be utilized +/// in the scan, i.e. no statistics are used. +struct IndexScore { + index: Option>, + score: usize, + constraints: Vec, +} + +/// Try to extract an index search from the WHERE clause +/// Returns an optional [Search] struct if an index search can be extracted, otherwise returns None. +pub fn try_extract_index_search_from_where_clause( + where_clause: &mut Vec, + table_index: usize, + table_reference: &TableReference, + table_indexes: &[Arc], +) -> Result> { + // If there are no WHERE terms, we can't extract a search + if where_clause.is_empty() { + return Ok(None); + } + // If there are no indexes, we can't extract a search + if table_indexes.is_empty() { + return Ok(None); + } + + let iter_dir = if let Operation::Scan { iter_dir, .. } = &table_reference.op { + *iter_dir + } else { + return Ok(None); + }; + + // Find all potential index constraints + // For WHERE terms to be used to constrain an index scan, they must: + // 1. refer to columns in the table that the index is on + // 2. be a binary comparison expression + // 3. constrain the index columns in the order that they appear in the index + // - e.g. if the index is on (a,b,c) then we can use all of "a = 1 AND b = 2 AND c = 3" to constrain the index scan, + // - but if the where clause is "a = 1 and c = 3" then we can only use "a = 1". + let mut constraints_cur = vec![]; + let mut best_index = IndexScore { + index: None, + score: 0, + constraints: vec![], + }; + + for index in table_indexes { + // Check how many terms in the where clause constrain the index in column order + find_index_constraints( + where_clause, + table_index, + table_reference, + index, + &mut constraints_cur, + )?; + // naive scoring since we don't have statistics: prefer the index where we can use the most columns + // e.g. if we can use all columns of an index on (a,b), it's better than an index of (c,d,e) where we can only use c. + let score = constraints_cur.len(); + if score > best_index.score { + best_index.index = Some(Arc::clone(index)); + best_index.score = score; + best_index.constraints.clear(); + best_index.constraints.append(&mut constraints_cur); + } + } + + if best_index.index.is_none() { + return Ok(None); + } + + // Build the seek definition + let seek_def = + build_seek_def_from_index_constraints(&best_index.constraints, iter_dir, where_clause)?; + + // Remove the used terms from the where_clause since they are now part of the seek definition + // Sort terms by position in descending order to avoid shifting indices during removal + best_index.constraints.sort_by(|a, b| { + b.position_in_where_clause + .0 + .cmp(&a.position_in_where_clause.0) + }); + + for constraint in best_index.constraints.iter() { + where_clause.remove(constraint.position_in_where_clause.0); + } + + return Ok(Some(Search::Seek { + index: best_index.index, + seek_def, + })); +} + +#[derive(Debug, Clone)] +/// A representation of an expression in a [WhereTerm] that can potentially be used as part of an index seek key. +/// For example, if there is an index on table T(x,y) and another index on table U(z), and the where clause is "WHERE x > 10 AND 20 = z", +/// the index constraints are: +/// - x > 10 ==> IndexConstraint { position_in_where_clause: (0, [BinaryExprSide::Rhs]), operator: [ast::Operator::Greater] } +/// - 20 = z ==> IndexConstraint { position_in_where_clause: (1, [BinaryExprSide::Lhs]), operator: [ast::Operator::Equals] } +pub struct IndexConstraint { + position_in_where_clause: (usize, BinaryExprSide), + operator: ast::Operator, +} + +/// Helper enum for [IndexConstraint] to indicate which side of a binary comparison expression is being compared to the index column. +/// For example, if the where clause is "WHERE x = 10" and there's an index on x, +/// the [IndexConstraint] for the where clause term "x = 10" will have a [BinaryExprSide::Rhs] +/// because the right hand side expression "10" is being compared to the index column "x". +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum BinaryExprSide { + Lhs, + Rhs, +} + +/// Get the position of a column in an index +/// For example, if there is an index on table T(x,y) then y's position in the index is 1. +fn get_column_position_in_index( + expr: &ast::Expr, + table_index: usize, + table_reference: &TableReference, + index: &Arc, +) -> Option { + let ast::Expr::Column { table, column, .. } = expr else { + return None; + }; + if *table != table_index { + return None; + } + let Some(column) = table_reference.table.get_column_at(*column) else { + return None; + }; + index + .columns + .iter() + .position(|col| Some(&col.name) == column.name.as_ref()) +} + +/// Find all [IndexConstraint]s for a given WHERE clause +/// Constraints are appended as long as they constrain the index in column order. +/// E.g. for index (a,b,c) to be fully used, there must be a [WhereTerm] for each of a, b, and c. +/// If e.g. only a and c are present, then only the first column 'a' of the index will be used. +fn find_index_constraints( + where_clause: &mut Vec, + table_index: usize, + table_reference: &TableReference, + index: &Arc, + out_constraints: &mut Vec, +) -> Result<()> { + for position_in_index in 0..index.columns.len() { + let mut found = false; + for (position_in_where_clause, term) in where_clause.iter().enumerate() { + // Skip terms that cannot be evaluated at this table's loop level + if !term.should_eval_at_loop(table_index) { + continue; + } + // Skip terms that are not binary comparisons + let ast::Expr::Binary(lhs, operator, rhs) = &term.expr else { + continue; + }; + // Only consider index scans for binary ops that are comparisons + if !matches!( + *operator, + ast::Operator::Equals + | ast::Operator::Greater + | ast::Operator::GreaterEquals + | ast::Operator::Less + | ast::Operator::LessEquals + ) { + continue; + } + + // Check if lhs is a column that is in the i'th position of the index + if Some(position_in_index) + == get_column_position_in_index(lhs, table_index, table_reference, index) + { + out_constraints.push(IndexConstraint { + operator: *operator, + position_in_where_clause: (position_in_where_clause, BinaryExprSide::Rhs), + }); + found = true; + break; + } + // Check if rhs is a column that is in the i'th position of the index + if Some(position_in_index) + == get_column_position_in_index(rhs, table_index, table_reference, index) + { + out_constraints.push(IndexConstraint { + operator: opposite_cmp_op(*operator), // swap the operator since e.g. if condition is 5 >= x, we want to use x <= 5 + position_in_where_clause: (position_in_where_clause, BinaryExprSide::Lhs), + }); + found = true; + break; + } + } + if !found { + // Expressions must constrain index columns in index definition order. If we didn't find a constraint for the i'th index column, + // then we stop here and return the constraints we have found so far. + break; + } + } + + // In a multicolumn index, only the last term can have a nonequality expression. + // For example, imagine an index on (x,y) and the where clause is "WHERE x > 10 AND y > 20"; + // We can't use GT(x: 10,y: 20) as the seek key, because the first row greater than (x: 10,y: 20) + // might be e.g. (x: 10,y: 21), which does not satisfy the where clause, but a row after that e.g. (x: 11,y: 21) does. + // So: + // - in this case only GT(x: 10) can be used as the seek key, and we must emit a regular condition expression for y > 20 while scanning. + // On the other hand, if the where clause is "WHERE x = 10 AND y > 20", we can use GT(x=10,y=20) as the seek key, + // because any rows where (x=10,y=20) < ROW < (x=11) will match the where clause. + for i in 0..out_constraints.len() { + if out_constraints[i].operator != ast::Operator::Equals { + out_constraints.truncate(i + 1); + break; + } + } + + Ok(()) +} + +/// Build a [SeekDef] for a given list of [IndexConstraint]s +pub fn build_seek_def_from_index_constraints( + constraints: &[IndexConstraint], + iter_dir: IterationDirection, + where_clause: &mut Vec, +) -> Result { + assert!( + !constraints.is_empty(), + "cannot build seek def from empty list of index constraints" + ); + // Extract the key values and operators + let mut key = Vec::with_capacity(constraints.len()); + + for constraint in constraints { + // Extract the other expression from the binary WhereTerm (i.e. the one being compared to the index column) + let (idx, side) = constraint.position_in_where_clause; + let where_term = &mut where_clause[idx]; + let ast::Expr::Binary(lhs, _, rhs) = where_term.expr.take_ownership() else { + crate::bail_parse_error!("expected binary expression"); + }; + let cmp_expr = if side == BinaryExprSide::Lhs { + *lhs + } else { + *rhs + }; + key.push(cmp_expr); + } + + // We know all but potentially the last term is an equality, so we can use the operator of the last term + // to form the SeekOp + let op = constraints.last().unwrap().operator; + + build_seek_def(op, iter_dir, key) +} + +/// Build a [SeekDef] for a given comparison operator and index key. +/// To be usable as a seek key, all but potentially the last term must be equalities. +/// The last term can be a nonequality. +/// The comparison operator referred to by `op` is the operator of the last term. +/// +/// There are two parts to the seek definition: +/// 1. The [SeekKey], which specifies the key that we will use to seek to the first row that matches the index key. +/// 2. The [TerminationKey], which specifies the key that we will use to terminate the index scan that follows the seek. +/// +/// There are some nuances to how, and which parts of, the index key can be used in the [SeekKey] and [TerminationKey], +/// depending on the operator and iteration direction. This function explains those nuances inline when dealing with +/// each case. +/// +/// But to illustrate the general idea, consider the following examples: +/// +/// 1. For example, having two conditions like (x>10 AND y>20) cannot be used as a valid [SeekKey] GT(x:10, y:20) +/// because the first row greater than (x:10, y:20) might be (x:10, y:21), which does not satisfy the where clause. +/// In this case, only GT(x:10) must be used as the [SeekKey], and rows with y <= 20 must be filtered as a regular condition expression for each value of x. +/// +/// 2. In contrast, having (x=10 AND y>20) forms a valid index key GT(x:10, y:20) because after the seek, we can simply terminate as soon as x > 10, +/// i.e. use GT(x:10, y:20) as the [SeekKey] and GT(x:10) as the [TerminationKey]. +/// +fn build_seek_def( + op: ast::Operator, + iter_dir: IterationDirection, + key: Vec, +) -> Result { + let key_len = key.len(); + Ok(match (iter_dir, op) { + // Forwards, EQ: + // Example: (x=10 AND y=20) + // Seek key: GE(x:10, y:20) + // Termination key: GT(x:10, y:20) + (IterationDirection::Forwards, ast::Operator::Equals) => SeekDef { + key, + iter_dir, + seek: Some(SeekKey { + len: key_len, + op: SeekOp::GE, + }), + termination: Some(TerminationKey { + len: key_len, + op: SeekOp::GT, + }), + }, + // Forwards, GT: + // Example: (x=10 AND y>20) + // Seek key: GT(x:10, y:20) + // Termination key: GT(x:10) + (IterationDirection::Forwards, ast::Operator::Greater) => { + let termination_key_len = key_len - 1; + SeekDef { + key, + iter_dir, + seek: Some(SeekKey { + len: key_len, + op: SeekOp::GT, + }), + termination: if termination_key_len > 0 { + Some(TerminationKey { + len: termination_key_len, + op: SeekOp::GT, + }) + } else { + None + }, + } + } + // Forwards, GE: + // Example: (x=10 AND y>=20) + // Seek key: GE(x:10, y:20) + // Termination key: GT(x:10) + (IterationDirection::Forwards, ast::Operator::GreaterEquals) => { + let termination_key_len = key_len - 1; + SeekDef { + key, + iter_dir, + seek: Some(SeekKey { + len: key_len, + op: SeekOp::GE, + }), + termination: if termination_key_len > 0 { + Some(TerminationKey { + len: termination_key_len, + op: SeekOp::GT, + }) + } else { + None + }, + } + } + // Forwards, LT: + // Example: (x=10 AND y<20) + // Seek key: GT(x:10, y: NULL) // NULL is always LT, indicating we only care about x + // Termination key: GE(x:10, y:20) + (IterationDirection::Forwards, ast::Operator::Less) => SeekDef { + key, + iter_dir, + seek: Some(SeekKey { + len: key_len - 1, + op: SeekOp::GT, + }), + termination: Some(TerminationKey { + len: key_len, + op: SeekOp::GE, + }), + }, + // Forwards, LE: + // Example: (x=10 AND y<=20) + // Seek key: GE(x:10, y:NULL) // NULL is always LT, indicating we only care about x + // Termination key: GT(x:10, y:20) + (IterationDirection::Forwards, ast::Operator::LessEquals) => SeekDef { + key, + iter_dir, + seek: Some(SeekKey { + len: key_len - 1, + op: SeekOp::GE, + }), + termination: Some(TerminationKey { + len: key_len, + op: SeekOp::GT, + }), + }, + // Backwards, EQ: + // Example: (x=10 AND y=20) + // Seek key: LE(x:10, y:20) + // Termination key: LT(x:10, y:20) + (IterationDirection::Backwards, ast::Operator::Equals) => SeekDef { + key, + iter_dir, + seek: Some(SeekKey { + len: key_len, + op: SeekOp::LE, + }), + termination: Some(TerminationKey { + len: key_len, + op: SeekOp::LT, + }), + }, + // Backwards, LT: + // Example: (x=10 AND y<20) + // Seek key: LT(x:10, y:20) + // Termination key: LT(x:10) + (IterationDirection::Backwards, ast::Operator::Less) => { + let termination_key_len = key_len - 1; + SeekDef { + key, + iter_dir, + seek: Some(SeekKey { + len: key_len, + op: SeekOp::LT, + }), + termination: if termination_key_len > 0 { + Some(TerminationKey { + len: termination_key_len, + op: SeekOp::LT, + }) + } else { + None + }, + } + } + // Backwards, LE: + // Example: (x=10 AND y<=20) + // Seek key: LE(x:10, y:20) + // Termination key: LT(x:10) + (IterationDirection::Backwards, ast::Operator::LessEquals) => { + let termination_key_len = key_len - 1; + SeekDef { + key, + iter_dir, + seek: Some(SeekKey { + len: key_len, + op: SeekOp::LE, + }), + termination: if termination_key_len > 0 { + Some(TerminationKey { + len: termination_key_len, + op: SeekOp::LT, + }) + } else { + None + }, + } + } + // Backwards, GT: + // Example: (x=10 AND y>20) + // Seek key: LE(x:10) // try to find the last row where x = 10, not considering y at all. + // Termination key: LE(x:10, y:20) + (IterationDirection::Backwards, ast::Operator::Greater) => { + let seek_key_len = key_len - 1; + SeekDef { + key, + iter_dir, + seek: if seek_key_len > 0 { + Some(SeekKey { + len: seek_key_len, + op: SeekOp::LE, + }) + } else { + None + }, + termination: Some(TerminationKey { + len: key_len, + op: SeekOp::LE, + }), + } + } + // Backwards, GE: + // Example: (x=10 AND y>=20) + // Seek key: LE(x:10) // try to find the last row where x = 10, not considering y at all. + // Termination key: LT(x:10, y:20) + (IterationDirection::Backwards, ast::Operator::GreaterEquals) => { + let seek_key_len = key_len - 1; + SeekDef { + key, + iter_dir, + seek: if seek_key_len > 0 { + Some(SeekKey { + len: seek_key_len, + op: SeekOp::LE, + }) + } else { + None + }, + termination: Some(TerminationKey { + len: key_len, + op: SeekOp::LT, + }), + } + } + (_, op) => { + crate::bail_parse_error!("build_seek_def: invalid operator: {:?}", op,) + } + }) +} + +pub fn try_extract_rowid_search_expression( cond: &mut WhereTerm, table_index: usize, table_reference: &TableReference, - available_indexes: &HashMap>>, - iter_dir: IterationDirection, ) -> Result> { + let iter_dir = if let Operation::Scan { iter_dir, .. } = &table_reference.op { + *iter_dir + } else { + return Ok(None); + }; if !cond.should_eval_at_loop(table_index) { return Ok(None); } @@ -681,14 +1135,10 @@ pub fn try_extract_index_search_expression( | ast::Operator::Less | ast::Operator::LessEquals => { let rhs_owned = rhs.take_ownership(); - return Ok(Some(Search::RowidSearch { - cmp_op: *operator, - cmp_expr: WhereTerm { - expr: rhs_owned, - from_outer_join: cond.from_outer_join, - eval_at: cond.eval_at, - }, - iter_dir, + let seek_def = build_seek_def(*operator, iter_dir, vec![rhs_owned])?; + return Ok(Some(Search::Seek { + index: None, + seek_def, })); } _ => {} @@ -712,64 +1162,11 @@ pub fn try_extract_index_search_expression( | ast::Operator::Less | ast::Operator::LessEquals => { let lhs_owned = lhs.take_ownership(); - return Ok(Some(Search::RowidSearch { - cmp_op: opposite_cmp_op(*operator), - cmp_expr: WhereTerm { - expr: lhs_owned, - from_outer_join: cond.from_outer_join, - eval_at: cond.eval_at, - }, - iter_dir, - })); - } - _ => {} - } - } - - if let Some(index_rc) = - lhs.check_index_scan(table_index, &table_reference, available_indexes)? - { - match operator { - ast::Operator::Equals - | ast::Operator::Greater - | ast::Operator::GreaterEquals - | ast::Operator::Less - | ast::Operator::LessEquals => { - let rhs_owned = rhs.take_ownership(); - return Ok(Some(Search::IndexSearch { - index: index_rc, - cmp_op: *operator, - cmp_expr: WhereTerm { - expr: rhs_owned, - from_outer_join: cond.from_outer_join, - eval_at: cond.eval_at, - }, - iter_dir, - })); - } - _ => {} - } - } - - if let Some(index_rc) = - rhs.check_index_scan(table_index, &table_reference, available_indexes)? - { - match operator { - ast::Operator::Equals - | ast::Operator::Greater - | ast::Operator::GreaterEquals - | ast::Operator::Less - | ast::Operator::LessEquals => { - let lhs_owned = lhs.take_ownership(); - return Ok(Some(Search::IndexSearch { - index: index_rc, - cmp_op: opposite_cmp_op(*operator), - cmp_expr: WhereTerm { - expr: lhs_owned, - from_outer_join: cond.from_outer_join, - eval_at: cond.eval_at, - }, - iter_dir, + let op = opposite_cmp_op(*operator); + let seek_def = build_seek_def(op, iter_dir, vec![lhs_owned])?; + return Ok(Some(Search::Seek { + index: None, + seek_def, })); } _ => {} diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 3958f9f81..ab7bc893c 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -7,13 +7,16 @@ use std::{ sync::Arc, }; -use crate::schema::{PseudoTable, Type}; use crate::{ function::AggFunc, schema::{BTreeTable, Column, Index, Table}, vdbe::BranchOffset, VirtualTable, }; +use crate::{ + schema::{PseudoTable, Type}, + types::SeekOp, +}; #[derive(Debug, Clone)] pub struct ResultSetColumn { @@ -325,6 +328,68 @@ impl TableReference { } } +/// A definition of a rowid/index search. +/// +/// [SeekKey] is the condition that is used to seek to a specific row in a table/index. +/// [TerminationKey] is the condition that is used to terminate the search after a seek. +#[derive(Debug, Clone)] +pub struct SeekDef { + /// The key to use when seeking and when terminating the scan that follows the seek. + /// For example, given: + /// - CREATE INDEX i ON t (x, y) + /// - SELECT * FROM t WHERE x = 1 AND y >= 30 + /// The key is [1, 30] + pub key: Vec, + /// The condition to use when seeking. See [SeekKey] for more details. + pub seek: Option, + /// The condition to use when terminating the scan that follows the seek. See [TerminationKey] for more details. + pub termination: Option, + /// The direction of the scan that follows the seek. + pub iter_dir: IterationDirection, +} + +impl SeekDef { + /// Whether we should null pad unset columns when seeking. + /// This is only done for forward seeks. + /// The reason it is done is that sometimes our full index key is not used in seeking. + /// See [SeekKey] for more details. + /// + /// For example, given: + /// - CREATE INDEX i ON t (x, y) + /// - SELECT * FROM t WHERE x = 1 AND y < 30 + /// We want to seek to the first row where x = 1, and then iterate forwards. + /// In this case, the seek key is GT(1, NULL) since '30' cannot be used to seek (since we want y < 30), + /// and any value of y will be greater than NULL. + /// + /// In backwards iteration direction, we do not null pad because we want to seek to the last row that matches the seek key. + /// For example, given: + /// - CREATE INDEX i ON t (x, y) + /// - SELECT * FROM t WHERE x = 1 AND y > 30 ORDER BY y + /// We want to seek to the last row where x = 1, and then iterate backwards. + /// In this case, the seek key is just LE(1) so any row with x = 1 will be a match. + pub fn null_pad_unset_cols(&self) -> bool { + self.iter_dir == IterationDirection::Forwards + } +} + +/// A condition to use when seeking. +#[derive(Debug, Clone)] +pub struct SeekKey { + /// How many columns from [SeekDef::key] are used in seeking. + pub len: usize, + /// The comparison operator to use when seeking. + pub op: SeekOp, +} + +#[derive(Debug, Clone)] +/// A condition to use when terminating the scan that follows a seek. +pub struct TerminationKey { + /// How many columns from [SeekDef::key] are used in terminating the scan that follows the seek. + pub len: usize, + /// The comparison operator to use when terminating the scan that follows the seek. + pub op: SeekOp, +} + /// An enum that represents a search operation that can be used to search for a row in a table using an index /// (i.e. a primary key or a secondary index) #[allow(clippy::enum_variant_names)] @@ -332,18 +397,10 @@ impl TableReference { pub enum Search { /// A rowid equality point lookup. This is a special case that uses the SeekRowid bytecode instruction and does not loop. RowidEq { cmp_expr: WhereTerm }, - /// A rowid search. Uses bytecode instructions like SeekGT, SeekGE etc. - RowidSearch { - cmp_op: ast::Operator, - cmp_expr: WhereTerm, - iter_dir: IterationDirection, - }, - /// A secondary index search. Uses bytecode instructions like SeekGE, SeekGT etc. - IndexSearch { - index: Arc, - cmp_op: ast::Operator, - cmp_expr: WhereTerm, - iter_dir: IterationDirection, + /// A search on a table btree (via `rowid`) or a secondary index search. Uses bytecode instructions like SeekGE, SeekGT etc. + Seek { + index: Option>, + seek_def: SeekDef, }, } @@ -420,14 +477,16 @@ impl Display for SelectPlan { writeln!(f, "{}SCAN {}", indent, table_name)?; } Operation::Search(search) => match search { - Search::RowidEq { .. } | Search::RowidSearch { .. } => { + Search::RowidEq { .. } | Search::Seek { index: None, .. } => { writeln!( f, "{}SEARCH {} USING INTEGER PRIMARY KEY (rowid=?)", indent, reference.identifier )?; } - Search::IndexSearch { index, .. } => { + Search::Seek { + index: Some(index), .. + } => { writeln!( f, "{}SEARCH {} USING INDEX {}", @@ -509,14 +568,16 @@ impl fmt::Display for UpdatePlan { } } Operation::Search(search) => match search { - Search::RowidEq { .. } | Search::RowidSearch { .. } => { + Search::RowidEq { .. } | Search::Seek { index: None, .. } => { writeln!( f, "{}SEARCH {} USING INTEGER PRIMARY KEY (rowid=?)", indent, reference.identifier )?; } - Search::IndexSearch { index, .. } => { + Search::Seek { + index: Some(index), .. + } => { writeln!( f, "{}SEARCH {} USING INDEX {}", diff --git a/core/translate/select.rs b/core/translate/select.rs index bde61880f..24a6331e5 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -411,8 +411,8 @@ fn count_plan_required_cursors(plan: &SelectPlan) -> usize { .map(|t| match &t.op { Operation::Scan { .. } => 1, Operation::Search(search) => match search { - Search::RowidEq { .. } | Search::RowidSearch { .. } => 1, - Search::IndexSearch { .. } => 2, // btree cursor and index cursor + Search::RowidEq { .. } => 1, + Search::Seek { index, .. } => 1 + index.is_some() as usize, }, Operation::Subquery { plan, .. } => count_plan_required_cursors(plan), }) diff --git a/core/types.rs b/core/types.rs index da6b778cc..3d531adfe 100644 --- a/core/types.rs +++ b/core/types.rs @@ -5,6 +5,7 @@ use crate::ext::{ExtValue, ExtValueType}; use crate::pseudo::PseudoCursor; use crate::storage::btree::BTreeCursor; use crate::storage::sqlite3_ondisk::write_varint; +use crate::translate::plan::IterationDirection; use crate::vdbe::sorter::Sorter; use crate::vdbe::{Register, VTabOpaqueCursor}; use crate::Result; @@ -1227,6 +1228,7 @@ pub enum CursorResult { } #[derive(Clone, Copy, PartialEq, Eq, Debug)] +/// The match condition of a table/index seek. pub enum SeekOp { EQ, GE, @@ -1235,6 +1237,24 @@ pub enum SeekOp { LT, } +impl SeekOp { + /// A given seek op implies an iteration direction. + /// + /// For example, a seek with SeekOp::GT implies: + /// Find the first table/index key that compares greater than the seek key + /// -> used in forwards iteration. + /// + /// A seek with SeekOp::LE implies: + /// Find the last table/index key that compares less than or equal to the seek key + /// -> used in backwards iteration. + pub fn iteration_direction(&self) -> IterationDirection { + match self { + SeekOp::EQ | SeekOp::GE | SeekOp::GT => IterationDirection::Forwards, + SeekOp::LE | SeekOp::LT => IterationDirection::Backwards, + } + } +} + #[derive(Clone, PartialEq, Debug)] pub enum SeekKey<'a> { TableRowId(u64), diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 5cf2e6cd2..4d2a96d10 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -1,46 +1,55 @@ #![allow(unused_variables)] -use crate::error::{LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_PRIMARYKEY}; -use crate::ext::ExtValue; -use crate::function::{AggFunc, ExtFunc, MathFunc, MathFuncArity, ScalarFunc, VectorFunc}; -use crate::functions::datetime::{ - exec_date, exec_datetime_full, exec_julianday, exec_strftime, exec_time, exec_unixepoch, +use crate::{ + error::{LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_PRIMARYKEY}, + ext::ExtValue, + function::{AggFunc, ExtFunc, MathFunc, MathFuncArity, ScalarFunc, VectorFunc}, + functions::{ + datetime::{ + exec_date, exec_datetime_full, exec_julianday, exec_strftime, exec_time, exec_unixepoch, + }, + printf::exec_printf, + }, }; -use crate::functions::printf::exec_printf; use std::{borrow::BorrowMut, rc::Rc}; -use crate::pseudo::PseudoCursor; -use crate::result::LimboResult; +use crate::{pseudo::PseudoCursor, result::LimboResult}; -use crate::schema::{affinity, Affinity}; -use crate::storage::btree::{BTreeCursor, BTreeKey}; +use crate::{ + schema::{affinity, Affinity}, + storage::btree::{BTreeCursor, BTreeKey}, +}; -use crate::storage::wal::CheckpointResult; -use crate::types::{ - AggContext, Cursor, CursorResult, ExternalAggState, OwnedValue, OwnedValueType, SeekKey, SeekOp, +use crate::{ + storage::wal::CheckpointResult, + types::{ + AggContext, Cursor, CursorResult, ExternalAggState, OwnedValue, OwnedValueType, SeekKey, + SeekOp, + }, + util::{ + cast_real_to_integer, cast_text_to_integer, cast_text_to_numeric, cast_text_to_real, + checked_cast_text_to_numeric, parse_schema_rows, RoundToPrecision, + }, + vdbe::{ + builder::CursorType, + insn::{IdxInsertFlags, Insn}, + }, + vector::{vector32, vector64, vector_distance_cos, vector_extract}, }; -use crate::util::{ - cast_real_to_integer, cast_text_to_integer, cast_text_to_numeric, cast_text_to_real, - checked_cast_text_to_numeric, parse_schema_rows, RoundToPrecision, -}; -use crate::vdbe::builder::CursorType; -use crate::vdbe::insn::{IdxInsertFlags, Insn}; -use crate::vector::{vector32, vector64, vector_distance_cos, vector_extract}; use crate::{info, MvCursor, RefValue, Row, StepResult, TransactionState}; -use super::insn::{ - exec_add, exec_and, exec_bit_and, exec_bit_not, exec_bit_or, exec_boolean_not, exec_concat, - exec_divide, exec_multiply, exec_or, exec_remainder, exec_shift_left, exec_shift_right, - exec_subtract, Cookie, RegisterOrLiteral, +use super::{ + insn::{Cookie, RegisterOrLiteral}, + HaltState, }; -use super::HaltState; use rand::thread_rng; -use super::likeop::{construct_like_escape_arg, exec_glob, exec_like_with_escape}; -use super::sorter::Sorter; +use super::{ + likeop::{construct_like_escape_arg, exec_glob, exec_like_with_escape}, + sorter::Sorter, +}; use regex::{Regex, RegexBuilder}; -use std::cell::RefCell; -use std::collections::HashMap; +use std::{cell::RefCell, collections::HashMap}; #[cfg(feature = "json")] use crate::{ @@ -3397,6 +3406,21 @@ pub fn op_function( let result = exec_time(values); state.registers[*dest] = Register::OwnedValue(result); } + ScalarFunc::TimeDiff => { + if arg_count != 2 { + state.registers[*dest] = Register::OwnedValue(OwnedValue::Null); + } else { + let start = state.registers[*start_reg].get_owned_value().clone(); + let end = state.registers[*start_reg + 1].get_owned_value().clone(); + + let result = crate::functions::datetime::exec_timediff(&[ + Register::OwnedValue(start), + Register::OwnedValue(end), + ]); + + state.registers[*dest] = Register::OwnedValue(result); + } + } ScalarFunc::TotalChanges => { let res = &program.connection.upgrade().unwrap().total_changes; let total_changes = res.get(); @@ -3496,6 +3520,14 @@ pub fn op_function( let result = exec_likely(value.get_owned_value()); state.registers[*dest] = Register::OwnedValue(result); } + ScalarFunc::Likelihood => { + assert_eq!(arg_count, 2); + let value = &state.registers[*start_reg]; + let probability = &state.registers[*start_reg + 1]; + let result = + exec_likelihood(value.get_owned_value(), probability.get_owned_value()); + state.registers[*dest] = Register::OwnedValue(result); + } }, crate::function::Func::Vector(vector_func) => match vector_func { VectorFunc::Vector => { @@ -3789,6 +3821,7 @@ pub fn op_idx_insert_async( pager: &Rc, mv_store: Option<&Rc>, ) -> Result { + dbg!("op_idx_insert_async"); if let Insn::IdxInsertAsync { cursor_id, record_reg, @@ -3807,29 +3840,29 @@ pub fn op_idx_insert_async( Register::Record(ref r) => r, _ => return Err(LimboError::InternalError("expected record".into())), }; - let moved_before = if index_meta.unique { - // check for uniqueness violation - match cursor.key_exists_in_index(record)? { - CursorResult::Ok(true) => { - return Err(LimboError::Constraint( - "UNIQUE constraint failed: duplicate key".into(), - )) - } - CursorResult::IO => return Ok(InsnFunctionStepResult::IO), - CursorResult::Ok(false) => {} - }; - false - } else { - flags.has(IdxInsertFlags::USE_SEEK) - }; - // To make this reentrant in case of `moved_before` = false, we need to check if the previous cursor.insert started // a write/balancing operation. If it did, it means we already moved to the place we wanted. let moved_before = if cursor.is_write_in_progress() { true } else { - moved_before + if index_meta.unique { + // check for uniqueness violation + match cursor.key_exists_in_index(record)? { + CursorResult::Ok(true) => { + return Err(LimboError::Constraint( + "UNIQUE constraint failed: duplicate key".into(), + )) + } + CursorResult::IO => return Ok(InsnFunctionStepResult::IO), + CursorResult::Ok(false) => {} + }; + false + } else { + flags.has(IdxInsertFlags::USE_SEEK) + } }; + + dbg!(moved_before); // Start insertion of row. This might trigger a balance procedure which will take care of moving to different pages, // therefore, we don't want to seek again if that happens, meaning we don't want to return on io without moving to `Await` opcode // because it could trigger a movement to child page after a balance root which will leave the current page as the root page. @@ -5355,10 +5388,894 @@ fn exec_likely(reg: &OwnedValue) -> OwnedValue { reg.clone() } +fn exec_likelihood(reg: &OwnedValue, _probability: &OwnedValue) -> OwnedValue { + reg.clone() +} + +pub fn exec_add(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + let result = match (lhs, rhs) { + (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { + let result = lhs.overflowing_add(*rhs); + if result.1 { + OwnedValue::Float(*lhs as f64 + *rhs as f64) + } else { + OwnedValue::Integer(result.0) + } + } + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs + rhs), + (OwnedValue::Float(f), OwnedValue::Integer(i)) + | (OwnedValue::Integer(i), OwnedValue::Float(f)) => OwnedValue::Float(*f + *i as f64), + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_add( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { + exec_add(&cast_text_to_numeric(text.as_str()), other) + } + _ => todo!(), + }; + match result { + OwnedValue::Float(f) if f.is_nan() => OwnedValue::Null, + _ => result, + } +} + +pub fn exec_subtract(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + let result = match (lhs, rhs) { + (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { + let result = lhs.overflowing_sub(*rhs); + if result.1 { + OwnedValue::Float(*lhs as f64 - *rhs as f64) + } else { + OwnedValue::Integer(result.0) + } + } + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs - rhs), + (OwnedValue::Float(lhs), OwnedValue::Integer(rhs)) => OwnedValue::Float(lhs - *rhs as f64), + (OwnedValue::Integer(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(*lhs as f64 - rhs), + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_subtract( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) => { + exec_subtract(&cast_text_to_numeric(text.as_str()), other) + } + (other, OwnedValue::Text(text)) => { + exec_subtract(other, &cast_text_to_numeric(text.as_str())) + } + _ => todo!(), + }; + match result { + OwnedValue::Float(f) if f.is_nan() => OwnedValue::Null, + _ => result, + } +} + +pub fn exec_multiply(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + let result = match (lhs, rhs) { + (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { + let result = lhs.overflowing_mul(*rhs); + if result.1 { + OwnedValue::Float(*lhs as f64 * *rhs as f64) + } else { + OwnedValue::Integer(result.0) + } + } + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs * rhs), + (OwnedValue::Integer(i), OwnedValue::Float(f)) + | (OwnedValue::Float(f), OwnedValue::Integer(i)) => OwnedValue::Float(*i as f64 * { *f }), + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_multiply( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { + exec_multiply(&cast_text_to_numeric(text.as_str()), other) + } + + _ => todo!(), + }; + match result { + OwnedValue::Float(f) if f.is_nan() => OwnedValue::Null, + _ => result, + } +} + +pub fn exec_divide(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + let result = match (lhs, rhs) { + (_, OwnedValue::Integer(0)) | (_, OwnedValue::Float(0.0)) => OwnedValue::Null, + (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { + let result = lhs.overflowing_div(*rhs); + if result.1 { + OwnedValue::Float(*lhs as f64 / *rhs as f64) + } else { + OwnedValue::Integer(result.0) + } + } + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs / rhs), + (OwnedValue::Float(lhs), OwnedValue::Integer(rhs)) => OwnedValue::Float(lhs / *rhs as f64), + (OwnedValue::Integer(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(*lhs as f64 / rhs), + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_divide( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) => exec_divide(&cast_text_to_numeric(text.as_str()), other), + (other, OwnedValue::Text(text)) => exec_divide(other, &cast_text_to_numeric(text.as_str())), + _ => todo!(), + }; + match result { + OwnedValue::Float(f) if f.is_nan() => OwnedValue::Null, + _ => result, + } +} + +pub fn exec_bit_and(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + match (lhs, rhs) { + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (_, OwnedValue::Integer(0)) + | (OwnedValue::Integer(0), _) + | (_, OwnedValue::Float(0.0)) + | (OwnedValue::Float(0.0), _) => OwnedValue::Integer(0), + (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(lh & rh), + (OwnedValue::Float(lh), OwnedValue::Float(rh)) => { + OwnedValue::Integer(*lh as i64 & *rh as i64) + } + (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(*lh as i64 & rh), + (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => OwnedValue::Integer(lh & *rh as i64), + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_bit_and( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { + exec_bit_and(&cast_text_to_numeric(text.as_str()), other) + } + _ => todo!(), + } +} + +pub fn exec_bit_or(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + match (lhs, rhs) { + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(lh | rh), + (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(*lh as i64 | rh), + (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => OwnedValue::Integer(lh | *rh as i64), + (OwnedValue::Float(lh), OwnedValue::Float(rh)) => { + OwnedValue::Integer(*lh as i64 | *rh as i64) + } + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_bit_or( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { + exec_bit_or(&cast_text_to_numeric(text.as_str()), other) + } + _ => todo!(), + } +} + +pub fn exec_remainder(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + match (lhs, rhs) { + (OwnedValue::Null, _) + | (_, OwnedValue::Null) + | (_, OwnedValue::Integer(0)) + | (_, OwnedValue::Float(0.0)) => OwnedValue::Null, + (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { + if rhs == &0 { + OwnedValue::Null + } else { + OwnedValue::Integer(lhs % rhs.abs()) + } + } + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => { + let rhs_int = *rhs as i64; + if rhs_int == 0 { + OwnedValue::Null + } else { + OwnedValue::Float(((*lhs as i64) % rhs_int.abs()) as f64) + } + } + (OwnedValue::Float(lhs), OwnedValue::Integer(rhs)) => { + if rhs == &0 { + OwnedValue::Null + } else { + OwnedValue::Float(((*lhs as i64) % rhs.abs()) as f64) + } + } + (OwnedValue::Integer(lhs), OwnedValue::Float(rhs)) => { + let rhs_int = *rhs as i64; + if rhs_int == 0 { + OwnedValue::Null + } else { + OwnedValue::Float((lhs % rhs_int.abs()) as f64) + } + } + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_remainder( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) => { + exec_remainder(&cast_text_to_numeric(text.as_str()), other) + } + (other, OwnedValue::Text(text)) => { + exec_remainder(other, &cast_text_to_numeric(text.as_str())) + } + other => todo!("remainder not implemented for: {:?} {:?}", lhs, other), + } +} + +pub fn exec_bit_not(reg: &OwnedValue) -> OwnedValue { + match reg { + OwnedValue::Null => OwnedValue::Null, + OwnedValue::Integer(i) => OwnedValue::Integer(!i), + OwnedValue::Float(f) => OwnedValue::Integer(!(*f as i64)), + OwnedValue::Text(text) => exec_bit_not(&cast_text_to_numeric(text.as_str())), + _ => todo!(), + } +} + +pub fn exec_shift_left(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + match (lhs, rhs) { + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => { + OwnedValue::Integer(compute_shl(*lh, *rh)) + } + (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => { + OwnedValue::Integer(compute_shl(*lh as i64, *rh)) + } + (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => { + OwnedValue::Integer(compute_shl(*lh, *rh as i64)) + } + (OwnedValue::Float(lh), OwnedValue::Float(rh)) => { + OwnedValue::Integer(compute_shl(*lh as i64, *rh as i64)) + } + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_shift_left( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) => { + exec_shift_left(&cast_text_to_numeric(text.as_str()), other) + } + (other, OwnedValue::Text(text)) => { + exec_shift_left(other, &cast_text_to_numeric(text.as_str())) + } + _ => todo!(), + } +} + +fn compute_shl(lhs: i64, rhs: i64) -> i64 { + if rhs == 0 { + lhs + } else if rhs > 0 { + // for positive shifts, if it's too large return 0 + if rhs >= 64 { + 0 + } else { + lhs << rhs + } + } else { + // for negative shifts, check if it's i64::MIN to avoid overflow on negation + if rhs == i64::MIN || rhs <= -64 { + if lhs < 0 { + -1 + } else { + 0 + } + } else { + lhs >> (-rhs) + } + } +} + +pub fn exec_shift_right(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + match (lhs, rhs) { + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => { + OwnedValue::Integer(compute_shr(*lh, *rh)) + } + (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => { + OwnedValue::Integer(compute_shr(*lh as i64, *rh)) + } + (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => { + OwnedValue::Integer(compute_shr(*lh, *rh as i64)) + } + (OwnedValue::Float(lh), OwnedValue::Float(rh)) => { + OwnedValue::Integer(compute_shr(*lh as i64, *rh as i64)) + } + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_shift_right( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) => { + exec_shift_right(&cast_text_to_numeric(text.as_str()), other) + } + (other, OwnedValue::Text(text)) => { + exec_shift_right(other, &cast_text_to_numeric(text.as_str())) + } + _ => todo!(), + } +} + +// compute binary shift to the right if rhs >= 0 and binary shift to the left - if rhs < 0 +// note, that binary shift to the right is sign-extended +fn compute_shr(lhs: i64, rhs: i64) -> i64 { + if rhs == 0 { + lhs + } else if rhs > 0 { + // for positive right shifts + if rhs >= 64 { + if lhs < 0 { + -1 + } else { + 0 + } + } else { + lhs >> rhs + } + } else { + // for negative right shifts, check if it's i64::MIN to avoid overflow + if rhs == i64::MIN || -rhs >= 64 { + 0 + } else { + lhs << (-rhs) + } + } +} + +pub fn exec_boolean_not(reg: &OwnedValue) -> OwnedValue { + match reg { + OwnedValue::Null => OwnedValue::Null, + OwnedValue::Integer(i) => OwnedValue::Integer((*i == 0) as i64), + OwnedValue::Float(f) => OwnedValue::Integer((*f == 0.0) as i64), + OwnedValue::Text(text) => exec_boolean_not(&cast_text_to_numeric(text.as_str())), + _ => todo!(), + } +} +pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + match (lhs, rhs) { + (OwnedValue::Text(lhs_text), OwnedValue::Text(rhs_text)) => { + OwnedValue::build_text(&(lhs_text.as_str().to_string() + rhs_text.as_str())) + } + (OwnedValue::Text(lhs_text), OwnedValue::Integer(rhs_int)) => { + OwnedValue::build_text(&(lhs_text.as_str().to_string() + &rhs_int.to_string())) + } + (OwnedValue::Text(lhs_text), OwnedValue::Float(rhs_float)) => { + OwnedValue::build_text(&(lhs_text.as_str().to_string() + &rhs_float.to_string())) + } + (OwnedValue::Integer(lhs_int), OwnedValue::Text(rhs_text)) => { + OwnedValue::build_text(&(lhs_int.to_string() + rhs_text.as_str())) + } + (OwnedValue::Integer(lhs_int), OwnedValue::Integer(rhs_int)) => { + OwnedValue::build_text(&(lhs_int.to_string() + &rhs_int.to_string())) + } + (OwnedValue::Integer(lhs_int), OwnedValue::Float(rhs_float)) => { + OwnedValue::build_text(&(lhs_int.to_string() + &rhs_float.to_string())) + } + (OwnedValue::Float(lhs_float), OwnedValue::Text(rhs_text)) => { + OwnedValue::build_text(&(lhs_float.to_string() + rhs_text.as_str())) + } + (OwnedValue::Float(lhs_float), OwnedValue::Integer(rhs_int)) => { + OwnedValue::build_text(&(lhs_float.to_string() + &rhs_int.to_string())) + } + (OwnedValue::Float(lhs_float), OwnedValue::Float(rhs_float)) => { + OwnedValue::build_text(&(lhs_float.to_string() + &rhs_float.to_string())) + } + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Blob(_), _) | (_, OwnedValue::Blob(_)) => { + todo!("TODO: Handle Blob conversion to String") + } + } +} + +pub fn exec_and(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + match (lhs, rhs) { + (_, OwnedValue::Integer(0)) + | (OwnedValue::Integer(0), _) + | (_, OwnedValue::Float(0.0)) + | (OwnedValue::Float(0.0), _) => OwnedValue::Integer(0), + (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_and( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { + exec_and(&cast_text_to_numeric(text.as_str()), other) + } + _ => OwnedValue::Integer(1), + } +} + +pub fn exec_or(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { + match (lhs, rhs) { + (OwnedValue::Null, OwnedValue::Null) + | (OwnedValue::Null, OwnedValue::Float(0.0)) + | (OwnedValue::Float(0.0), OwnedValue::Null) + | (OwnedValue::Null, OwnedValue::Integer(0)) + | (OwnedValue::Integer(0), OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Float(0.0), OwnedValue::Integer(0)) + | (OwnedValue::Integer(0), OwnedValue::Float(0.0)) + | (OwnedValue::Float(0.0), OwnedValue::Float(0.0)) + | (OwnedValue::Integer(0), OwnedValue::Integer(0)) => OwnedValue::Integer(0), + (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_or( + &cast_text_to_numeric(lhs.as_str()), + &cast_text_to_numeric(rhs.as_str()), + ), + (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { + exec_or(&cast_text_to_numeric(text.as_str()), other) + } + _ => OwnedValue::Integer(1), + } +} + #[cfg(test)] mod tests { + use crate::types::{OwnedValue, Text}; + + use super::{exec_add, exec_or}; + + #[test] + fn test_exec_add() { + let inputs = vec![ + (OwnedValue::Integer(3), OwnedValue::Integer(1)), + (OwnedValue::Float(3.0), OwnedValue::Float(1.0)), + (OwnedValue::Float(3.0), OwnedValue::Integer(1)), + (OwnedValue::Integer(3), OwnedValue::Float(1.0)), + (OwnedValue::Null, OwnedValue::Null), + (OwnedValue::Null, OwnedValue::Integer(1)), + (OwnedValue::Null, OwnedValue::Float(1.0)), + (OwnedValue::Null, OwnedValue::Text(Text::from_str("2"))), + (OwnedValue::Integer(1), OwnedValue::Null), + (OwnedValue::Float(1.0), OwnedValue::Null), + (OwnedValue::Text(Text::from_str("1")), OwnedValue::Null), + ( + OwnedValue::Text(Text::from_str("1")), + OwnedValue::Text(Text::from_str("3")), + ), + ( + OwnedValue::Text(Text::from_str("1.0")), + OwnedValue::Text(Text::from_str("3.0")), + ), + ( + OwnedValue::Text(Text::from_str("1.0")), + OwnedValue::Float(3.0), + ), + ( + OwnedValue::Text(Text::from_str("1.0")), + OwnedValue::Integer(3), + ), + ( + OwnedValue::Float(1.0), + OwnedValue::Text(Text::from_str("3.0")), + ), + ( + OwnedValue::Integer(1), + OwnedValue::Text(Text::from_str("3")), + ), + ]; + + let outputs = [ + OwnedValue::Integer(4), + OwnedValue::Float(4.0), + OwnedValue::Float(4.0), + OwnedValue::Float(4.0), + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Integer(4), + OwnedValue::Float(4.0), + OwnedValue::Float(4.0), + OwnedValue::Float(4.0), + OwnedValue::Float(4.0), + OwnedValue::Float(4.0), + ]; + + assert_eq!( + inputs.len(), + outputs.len(), + "Inputs and Outputs should have same size" + ); + for (i, (lhs, rhs)) in inputs.iter().enumerate() { + assert_eq!( + exec_add(lhs, rhs), + outputs[i], + "Wrong ADD for lhs: {}, rhs: {}", + lhs, + rhs + ); + } + } + + use super::exec_subtract; + + #[test] + fn test_exec_subtract() { + let inputs = vec![ + (OwnedValue::Integer(3), OwnedValue::Integer(1)), + (OwnedValue::Float(3.0), OwnedValue::Float(1.0)), + (OwnedValue::Float(3.0), OwnedValue::Integer(1)), + (OwnedValue::Integer(3), OwnedValue::Float(1.0)), + (OwnedValue::Null, OwnedValue::Null), + (OwnedValue::Null, OwnedValue::Integer(1)), + (OwnedValue::Null, OwnedValue::Float(1.0)), + (OwnedValue::Null, OwnedValue::Text(Text::from_str("1"))), + (OwnedValue::Integer(1), OwnedValue::Null), + (OwnedValue::Float(1.0), OwnedValue::Null), + (OwnedValue::Text(Text::from_str("4")), OwnedValue::Null), + ( + OwnedValue::Text(Text::from_str("1")), + OwnedValue::Text(Text::from_str("3")), + ), + ( + OwnedValue::Text(Text::from_str("1.0")), + OwnedValue::Text(Text::from_str("3.0")), + ), + ( + OwnedValue::Text(Text::from_str("1.0")), + OwnedValue::Float(3.0), + ), + ( + OwnedValue::Text(Text::from_str("1.0")), + OwnedValue::Integer(3), + ), + ( + OwnedValue::Float(1.0), + OwnedValue::Text(Text::from_str("3.0")), + ), + ( + OwnedValue::Integer(1), + OwnedValue::Text(Text::from_str("3")), + ), + ]; + + let outputs = [ + OwnedValue::Integer(2), + OwnedValue::Float(2.0), + OwnedValue::Float(2.0), + OwnedValue::Float(2.0), + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Integer(-2), + OwnedValue::Float(-2.0), + OwnedValue::Float(-2.0), + OwnedValue::Float(-2.0), + OwnedValue::Float(-2.0), + OwnedValue::Float(-2.0), + ]; + + assert_eq!( + inputs.len(), + outputs.len(), + "Inputs and Outputs should have same size" + ); + for (i, (lhs, rhs)) in inputs.iter().enumerate() { + assert_eq!( + exec_subtract(lhs, rhs), + outputs[i], + "Wrong subtract for lhs: {}, rhs: {}", + lhs, + rhs + ); + } + } + use super::exec_multiply; + + #[test] + fn test_exec_multiply() { + let inputs = vec![ + (OwnedValue::Integer(3), OwnedValue::Integer(2)), + (OwnedValue::Float(3.0), OwnedValue::Float(2.0)), + (OwnedValue::Float(3.0), OwnedValue::Integer(2)), + (OwnedValue::Integer(3), OwnedValue::Float(2.0)), + (OwnedValue::Null, OwnedValue::Null), + (OwnedValue::Null, OwnedValue::Integer(1)), + (OwnedValue::Null, OwnedValue::Float(1.0)), + (OwnedValue::Null, OwnedValue::Text(Text::from_str("1"))), + (OwnedValue::Integer(1), OwnedValue::Null), + (OwnedValue::Float(1.0), OwnedValue::Null), + (OwnedValue::Text(Text::from_str("4")), OwnedValue::Null), + ( + OwnedValue::Text(Text::from_str("2")), + OwnedValue::Text(Text::from_str("3")), + ), + ( + OwnedValue::Text(Text::from_str("2.0")), + OwnedValue::Text(Text::from_str("3.0")), + ), + ( + OwnedValue::Text(Text::from_str("2.0")), + OwnedValue::Float(3.0), + ), + ( + OwnedValue::Text(Text::from_str("2.0")), + OwnedValue::Integer(3), + ), + ( + OwnedValue::Float(2.0), + OwnedValue::Text(Text::from_str("3.0")), + ), + ( + OwnedValue::Integer(2), + OwnedValue::Text(Text::from_str("3.0")), + ), + ]; + + let outputs = [ + OwnedValue::Integer(6), + OwnedValue::Float(6.0), + OwnedValue::Float(6.0), + OwnedValue::Float(6.0), + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Integer(6), + OwnedValue::Float(6.0), + OwnedValue::Float(6.0), + OwnedValue::Float(6.0), + OwnedValue::Float(6.0), + OwnedValue::Float(6.0), + ]; + + assert_eq!( + inputs.len(), + outputs.len(), + "Inputs and Outputs should have same size" + ); + for (i, (lhs, rhs)) in inputs.iter().enumerate() { + assert_eq!( + exec_multiply(lhs, rhs), + outputs[i], + "Wrong multiply for lhs: {}, rhs: {}", + lhs, + rhs + ); + } + } + use super::exec_divide; + + #[test] + fn test_exec_divide() { + let inputs = vec![ + (OwnedValue::Integer(1), OwnedValue::Integer(0)), + (OwnedValue::Float(1.0), OwnedValue::Float(0.0)), + (OwnedValue::Integer(i64::MIN), OwnedValue::Integer(-1)), + (OwnedValue::Float(6.0), OwnedValue::Float(2.0)), + (OwnedValue::Float(6.0), OwnedValue::Integer(2)), + (OwnedValue::Integer(6), OwnedValue::Integer(2)), + (OwnedValue::Null, OwnedValue::Integer(2)), + (OwnedValue::Integer(2), OwnedValue::Null), + (OwnedValue::Null, OwnedValue::Null), + ( + OwnedValue::Text(Text::from_str("6")), + OwnedValue::Text(Text::from_str("2")), + ), + ( + OwnedValue::Text(Text::from_str("6")), + OwnedValue::Integer(2), + ), + ]; + + let outputs = [ + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Float(9.223372036854776e18), + OwnedValue::Float(3.0), + OwnedValue::Float(3.0), + OwnedValue::Float(3.0), + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Float(3.0), + OwnedValue::Float(3.0), + ]; + + assert_eq!( + inputs.len(), + outputs.len(), + "Inputs and Outputs should have same size" + ); + for (i, (lhs, rhs)) in inputs.iter().enumerate() { + assert_eq!( + exec_divide(lhs, rhs), + outputs[i], + "Wrong divide for lhs: {}, rhs: {}", + lhs, + rhs + ); + } + } + + use super::exec_remainder; + #[test] + fn test_exec_remainder() { + let inputs = vec![ + (OwnedValue::Null, OwnedValue::Null), + (OwnedValue::Null, OwnedValue::Float(1.0)), + (OwnedValue::Null, OwnedValue::Integer(1)), + (OwnedValue::Null, OwnedValue::Text(Text::from_str("1"))), + (OwnedValue::Float(1.0), OwnedValue::Null), + (OwnedValue::Integer(1), OwnedValue::Null), + (OwnedValue::Integer(12), OwnedValue::Integer(0)), + (OwnedValue::Float(12.0), OwnedValue::Float(0.0)), + (OwnedValue::Float(12.0), OwnedValue::Integer(0)), + (OwnedValue::Integer(12), OwnedValue::Float(0.0)), + (OwnedValue::Integer(i64::MIN), OwnedValue::Integer(-1)), + (OwnedValue::Integer(12), OwnedValue::Integer(3)), + (OwnedValue::Float(12.0), OwnedValue::Float(3.0)), + (OwnedValue::Float(12.0), OwnedValue::Integer(3)), + (OwnedValue::Integer(12), OwnedValue::Float(3.0)), + (OwnedValue::Integer(12), OwnedValue::Integer(-3)), + (OwnedValue::Float(12.0), OwnedValue::Float(-3.0)), + (OwnedValue::Float(12.0), OwnedValue::Integer(-3)), + (OwnedValue::Integer(12), OwnedValue::Float(-3.0)), + ( + OwnedValue::Text(Text::from_str("12.0")), + OwnedValue::Text(Text::from_str("3.0")), + ), + ( + OwnedValue::Text(Text::from_str("12.0")), + OwnedValue::Float(3.0), + ), + ( + OwnedValue::Float(12.0), + OwnedValue::Text(Text::from_str("3.0")), + ), + ]; + let outputs = vec![ + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Float(0.0), + OwnedValue::Integer(0), + OwnedValue::Float(0.0), + OwnedValue::Float(0.0), + OwnedValue::Float(0.0), + OwnedValue::Integer(0), + OwnedValue::Float(0.0), + OwnedValue::Float(0.0), + OwnedValue::Float(0.0), + OwnedValue::Float(0.0), + OwnedValue::Float(0.0), + OwnedValue::Float(0.0), + ]; + + assert_eq!( + inputs.len(), + outputs.len(), + "Inputs and Outputs should have same size" + ); + + for (i, (lhs, rhs)) in inputs.iter().enumerate() { + assert_eq!( + exec_remainder(lhs, rhs), + outputs[i], + "Wrong remainder for lhs: {}, rhs: {}", + lhs, + rhs + ); + } + } + + use super::exec_and; + + #[test] + fn test_exec_and() { + let inputs = vec![ + (OwnedValue::Integer(0), OwnedValue::Null), + (OwnedValue::Null, OwnedValue::Integer(1)), + (OwnedValue::Null, OwnedValue::Null), + (OwnedValue::Float(0.0), OwnedValue::Null), + (OwnedValue::Integer(1), OwnedValue::Float(2.2)), + ( + OwnedValue::Integer(0), + OwnedValue::Text(Text::from_str("string")), + ), + ( + OwnedValue::Integer(0), + OwnedValue::Text(Text::from_str("1")), + ), + ( + OwnedValue::Integer(1), + OwnedValue::Text(Text::from_str("1")), + ), + ]; + let outputs = [ + OwnedValue::Integer(0), + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Integer(0), + OwnedValue::Integer(1), + OwnedValue::Integer(0), + OwnedValue::Integer(0), + OwnedValue::Integer(1), + ]; + + assert_eq!( + inputs.len(), + outputs.len(), + "Inputs and Outputs should have same size" + ); + for (i, (lhs, rhs)) in inputs.iter().enumerate() { + assert_eq!( + exec_and(lhs, rhs), + outputs[i], + "Wrong AND for lhs: {}, rhs: {}", + lhs, + rhs + ); + } + } + + #[test] + fn test_exec_or() { + let inputs = vec![ + (OwnedValue::Integer(0), OwnedValue::Null), + (OwnedValue::Null, OwnedValue::Integer(1)), + (OwnedValue::Null, OwnedValue::Null), + (OwnedValue::Float(0.0), OwnedValue::Null), + (OwnedValue::Integer(1), OwnedValue::Float(2.2)), + (OwnedValue::Float(0.0), OwnedValue::Integer(0)), + ( + OwnedValue::Integer(0), + OwnedValue::Text(Text::from_str("string")), + ), + ( + OwnedValue::Integer(0), + OwnedValue::Text(Text::from_str("1")), + ), + (OwnedValue::Integer(0), OwnedValue::Text(Text::from_str(""))), + ]; + let outputs = [ + OwnedValue::Null, + OwnedValue::Integer(1), + OwnedValue::Null, + OwnedValue::Null, + OwnedValue::Integer(1), + OwnedValue::Integer(0), + OwnedValue::Integer(0), + OwnedValue::Integer(1), + OwnedValue::Integer(0), + ]; + + assert_eq!( + inputs.len(), + outputs.len(), + "Inputs and Outputs should have same size" + ); + for (i, (lhs, rhs)) in inputs.iter().enumerate() { + assert_eq!( + exec_or(lhs, rhs), + outputs[i], + "Wrong OR for lhs: {}, rhs: {}", + lhs, + rhs + ); + } + } + use crate::vdbe::{ - execute::{exec_likely, exec_replace}, + execute::{exec_likelihood, exec_likely, exec_replace}, Bitfield, Register, }; @@ -5366,7 +6283,7 @@ mod tests { exec_abs, exec_char, exec_hex, exec_if, exec_instr, exec_length, exec_like, exec_lower, exec_ltrim, exec_max, exec_min, exec_nullif, exec_quote, exec_random, exec_randomblob, exec_round, exec_rtrim, exec_sign, exec_soundex, exec_substring, exec_trim, exec_typeof, - exec_unhex, exec_unicode, exec_upper, exec_zeroblob, execute_sqlite_version, OwnedValue, + exec_unhex, exec_unicode, exec_upper, exec_zeroblob, execute_sqlite_version, }; use std::collections::HashMap; @@ -6275,6 +7192,39 @@ mod tests { assert_eq!(exec_likely(&input), expected); } + #[test] + fn test_likelihood() { + let value = OwnedValue::build_text("limbo"); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::build_text("database"); + let prob = OwnedValue::Float(0.9375); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Integer(100); + let prob = OwnedValue::Float(1.0); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Float(12.34); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Null; + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let value = OwnedValue::Blob(vec![1, 2, 3, 4]); + let prob = OwnedValue::Float(0.5); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::build_text("0.5"); + assert_eq!(exec_likelihood(&value, &prob), value); + + let prob = OwnedValue::Null; + assert_eq!(exec_likelihood(&value, &prob), value); + } + #[test] fn test_bitfield() { let mut bitfield = Bitfield::<4>::new(); diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 3d46bc41b..3ce60f5db 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -748,28 +748,28 @@ pub fn insn_to_str( is_index: _, cursor_id, start_reg, - num_regs: _, + num_regs, target_pc, } | Insn::SeekGE { is_index: _, cursor_id, start_reg, - num_regs: _, + num_regs, target_pc, } | Insn::SeekLE { is_index: _, cursor_id, start_reg, - num_regs: _, + num_regs, target_pc, } | Insn::SeekLT { is_index: _, cursor_id, start_reg, - num_regs: _, + num_regs, target_pc, } => ( match insn { @@ -784,7 +784,7 @@ pub fn insn_to_str( *start_reg as i32, OwnedValue::build_text(""), 0, - "".to_string(), + format!("key=[{}..{}]", start_reg, start_reg + num_regs - 1), ), Insn::SeekEnd { cursor_id } => ( "SeekEnd", @@ -822,58 +822,40 @@ pub fn insn_to_str( Insn::IdxGT { cursor_id, start_reg, - num_regs: _, + num_regs, target_pc, - } => ( - "IdxGT", - *cursor_id as i32, - target_pc.to_debug_int(), - *start_reg as i32, - OwnedValue::build_text(""), - 0, - "".to_string(), - ), - Insn::IdxGE { + } + | Insn::IdxGE { cursor_id, start_reg, - num_regs: _, + num_regs, target_pc, - } => ( - "IdxGE", - *cursor_id as i32, - target_pc.to_debug_int(), - *start_reg as i32, - OwnedValue::build_text(""), - 0, - "".to_string(), - ), - Insn::IdxLT { + } + | Insn::IdxLE { cursor_id, start_reg, - num_regs: _, + num_regs, target_pc, - } => ( - "IdxLT", - *cursor_id as i32, - target_pc.to_debug_int(), - *start_reg as i32, - OwnedValue::build_text(""), - 0, - "".to_string(), - ), - Insn::IdxLE { + } + | Insn::IdxLT { cursor_id, start_reg, - num_regs: _, + num_regs, target_pc, } => ( - "IdxLE", + match insn { + Insn::IdxGT { .. } => "IdxGT", + Insn::IdxGE { .. } => "IdxGE", + Insn::IdxLE { .. } => "IdxLE", + Insn::IdxLT { .. } => "IdxLT", + _ => unreachable!(), + }, *cursor_id as i32, target_pc.to_debug_int(), *start_reg as i32, OwnedValue::build_text(""), 0, - "".to_string(), + format!("key=[{}..{}]", start_reg, start_reg + num_regs - 1), ), Insn::DecrJumpZero { reg, target_pc } => ( "DecrJumpZero", diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 607949efb..0047f9d11 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -1,12 +1,7 @@ -use std::num::NonZero; -use std::rc::Rc; +use std::{num::NonZero, rc::Rc}; -use super::{ - cast_text_to_numeric, execute, AggFunc, BranchOffset, CursorID, FuncCtx, InsnFunction, PageIdx, -}; -use crate::schema::BTreeTable; -use crate::storage::wal::CheckpointMode; -use crate::types::{OwnedValue, Record}; +use super::{execute, AggFunc, BranchOffset, CursorID, FuncCtx, InsnFunction, PageIdx}; +use crate::{schema::BTreeTable, storage::wal::CheckpointMode, types::Record}; use limbo_macros::Description; /// Flags provided to comparison instructions (e.g. Eq, Ne) which determine behavior related to NULL values. @@ -815,440 +810,6 @@ pub enum Insn { }, } -// TODO: Add remaining cookies. -#[derive(Description, Debug, Clone, Copy)] -pub enum Cookie { - /// The schema cookie. - SchemaVersion = 1, - /// The schema format number. Supported schema formats are 1, 2, 3, and 4. - DatabaseFormat = 2, - /// Default page cache size. - DefaultPageCacheSize = 3, - /// The page number of the largest root b-tree page when in auto-vacuum or incremental-vacuum modes, or zero otherwise. - LargestRootPageNumber = 4, - /// The database text encoding. A value of 1 means UTF-8. A value of 2 means UTF-16le. A value of 3 means UTF-16be. - DatabaseTextEncoding = 5, - /// The "user version" as read and set by the user_version pragma. - UserVersion = 6, -} - -pub fn exec_add(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - let result = match (lhs, rhs) { - (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { - let result = lhs.overflowing_add(*rhs); - if result.1 { - OwnedValue::Float(*lhs as f64 + *rhs as f64) - } else { - OwnedValue::Integer(result.0) - } - } - (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs + rhs), - (OwnedValue::Float(f), OwnedValue::Integer(i)) - | (OwnedValue::Integer(i), OwnedValue::Float(f)) => OwnedValue::Float(*f + *i as f64), - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_add( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_add(&cast_text_to_numeric(text.as_str()), other) - } - _ => todo!(), - }; - match result { - OwnedValue::Float(f) if f.is_nan() => OwnedValue::Null, - _ => result, - } -} - -pub fn exec_subtract(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - let result = match (lhs, rhs) { - (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { - let result = lhs.overflowing_sub(*rhs); - if result.1 { - OwnedValue::Float(*lhs as f64 - *rhs as f64) - } else { - OwnedValue::Integer(result.0) - } - } - (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs - rhs), - (OwnedValue::Float(lhs), OwnedValue::Integer(rhs)) => OwnedValue::Float(lhs - *rhs as f64), - (OwnedValue::Integer(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(*lhs as f64 - rhs), - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_subtract( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) => { - exec_subtract(&cast_text_to_numeric(text.as_str()), other) - } - (other, OwnedValue::Text(text)) => { - exec_subtract(other, &cast_text_to_numeric(text.as_str())) - } - _ => todo!(), - }; - match result { - OwnedValue::Float(f) if f.is_nan() => OwnedValue::Null, - _ => result, - } -} - -pub fn exec_multiply(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - let result = match (lhs, rhs) { - (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { - let result = lhs.overflowing_mul(*rhs); - if result.1 { - OwnedValue::Float(*lhs as f64 * *rhs as f64) - } else { - OwnedValue::Integer(result.0) - } - } - (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs * rhs), - (OwnedValue::Integer(i), OwnedValue::Float(f)) - | (OwnedValue::Float(f), OwnedValue::Integer(i)) => OwnedValue::Float(*i as f64 * { *f }), - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_multiply( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_multiply(&cast_text_to_numeric(text.as_str()), other) - } - - _ => todo!(), - }; - match result { - OwnedValue::Float(f) if f.is_nan() => OwnedValue::Null, - _ => result, - } -} - -pub fn exec_divide(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - let result = match (lhs, rhs) { - (_, OwnedValue::Integer(0)) | (_, OwnedValue::Float(0.0)) => OwnedValue::Null, - (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { - let result = lhs.overflowing_div(*rhs); - if result.1 { - OwnedValue::Float(*lhs as f64 / *rhs as f64) - } else { - OwnedValue::Integer(result.0) - } - } - (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs / rhs), - (OwnedValue::Float(lhs), OwnedValue::Integer(rhs)) => OwnedValue::Float(lhs / *rhs as f64), - (OwnedValue::Integer(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(*lhs as f64 / rhs), - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_divide( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) => exec_divide(&cast_text_to_numeric(text.as_str()), other), - (other, OwnedValue::Text(text)) => exec_divide(other, &cast_text_to_numeric(text.as_str())), - _ => todo!(), - }; - match result { - OwnedValue::Float(f) if f.is_nan() => OwnedValue::Null, - _ => result, - } -} - -pub fn exec_bit_and(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - match (lhs, rhs) { - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (_, OwnedValue::Integer(0)) - | (OwnedValue::Integer(0), _) - | (_, OwnedValue::Float(0.0)) - | (OwnedValue::Float(0.0), _) => OwnedValue::Integer(0), - (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(lh & rh), - (OwnedValue::Float(lh), OwnedValue::Float(rh)) => { - OwnedValue::Integer(*lh as i64 & *rh as i64) - } - (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(*lh as i64 & rh), - (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => OwnedValue::Integer(lh & *rh as i64), - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_bit_and( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_bit_and(&cast_text_to_numeric(text.as_str()), other) - } - _ => todo!(), - } -} - -pub fn exec_bit_or(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - match (lhs, rhs) { - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(lh | rh), - (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(*lh as i64 | rh), - (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => OwnedValue::Integer(lh | *rh as i64), - (OwnedValue::Float(lh), OwnedValue::Float(rh)) => { - OwnedValue::Integer(*lh as i64 | *rh as i64) - } - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_bit_or( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_bit_or(&cast_text_to_numeric(text.as_str()), other) - } - _ => todo!(), - } -} - -pub fn exec_remainder(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - match (lhs, rhs) { - (OwnedValue::Null, _) - | (_, OwnedValue::Null) - | (_, OwnedValue::Integer(0)) - | (_, OwnedValue::Float(0.0)) => OwnedValue::Null, - (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { - if rhs == &0 { - OwnedValue::Null - } else { - OwnedValue::Integer(lhs % rhs.abs()) - } - } - (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => { - let rhs_int = *rhs as i64; - if rhs_int == 0 { - OwnedValue::Null - } else { - OwnedValue::Float(((*lhs as i64) % rhs_int.abs()) as f64) - } - } - (OwnedValue::Float(lhs), OwnedValue::Integer(rhs)) => { - if rhs == &0 { - OwnedValue::Null - } else { - OwnedValue::Float(((*lhs as i64) % rhs.abs()) as f64) - } - } - (OwnedValue::Integer(lhs), OwnedValue::Float(rhs)) => { - let rhs_int = *rhs as i64; - if rhs_int == 0 { - OwnedValue::Null - } else { - OwnedValue::Float((lhs % rhs_int.abs()) as f64) - } - } - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_remainder( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) => { - exec_remainder(&cast_text_to_numeric(text.as_str()), other) - } - (other, OwnedValue::Text(text)) => { - exec_remainder(other, &cast_text_to_numeric(text.as_str())) - } - other => todo!("remainder not implemented for: {:?} {:?}", lhs, other), - } -} - -pub fn exec_bit_not(reg: &OwnedValue) -> OwnedValue { - match reg { - OwnedValue::Null => OwnedValue::Null, - OwnedValue::Integer(i) => OwnedValue::Integer(!i), - OwnedValue::Float(f) => OwnedValue::Integer(!(*f as i64)), - OwnedValue::Text(text) => exec_bit_not(&cast_text_to_numeric(text.as_str())), - _ => todo!(), - } -} - -pub fn exec_shift_left(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - match (lhs, rhs) { - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => { - OwnedValue::Integer(compute_shl(*lh, *rh)) - } - (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => { - OwnedValue::Integer(compute_shl(*lh as i64, *rh)) - } - (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => { - OwnedValue::Integer(compute_shl(*lh, *rh as i64)) - } - (OwnedValue::Float(lh), OwnedValue::Float(rh)) => { - OwnedValue::Integer(compute_shl(*lh as i64, *rh as i64)) - } - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_shift_left( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) => { - exec_shift_left(&cast_text_to_numeric(text.as_str()), other) - } - (other, OwnedValue::Text(text)) => { - exec_shift_left(other, &cast_text_to_numeric(text.as_str())) - } - _ => todo!(), - } -} - -fn compute_shl(lhs: i64, rhs: i64) -> i64 { - if rhs == 0 { - lhs - } else if rhs > 0 { - // for positive shifts, if it's too large return 0 - if rhs >= 64 { - 0 - } else { - lhs << rhs - } - } else { - // for negative shifts, check if it's i64::MIN to avoid overflow on negation - if rhs == i64::MIN || rhs <= -64 { - if lhs < 0 { - -1 - } else { - 0 - } - } else { - lhs >> (-rhs) - } - } -} - -pub fn exec_shift_right(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - match (lhs, rhs) { - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => { - OwnedValue::Integer(compute_shr(*lh, *rh)) - } - (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => { - OwnedValue::Integer(compute_shr(*lh as i64, *rh)) - } - (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => { - OwnedValue::Integer(compute_shr(*lh, *rh as i64)) - } - (OwnedValue::Float(lh), OwnedValue::Float(rh)) => { - OwnedValue::Integer(compute_shr(*lh as i64, *rh as i64)) - } - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_shift_right( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) => { - exec_shift_right(&cast_text_to_numeric(text.as_str()), other) - } - (other, OwnedValue::Text(text)) => { - exec_shift_right(other, &cast_text_to_numeric(text.as_str())) - } - _ => todo!(), - } -} - -// compute binary shift to the right if rhs >= 0 and binary shift to the left - if rhs < 0 -// note, that binary shift to the right is sign-extended -fn compute_shr(lhs: i64, rhs: i64) -> i64 { - if rhs == 0 { - lhs - } else if rhs > 0 { - // for positive right shifts - if rhs >= 64 { - if lhs < 0 { - -1 - } else { - 0 - } - } else { - lhs >> rhs - } - } else { - // for negative right shifts, check if it's i64::MIN to avoid overflow - if rhs == i64::MIN || -rhs >= 64 { - 0 - } else { - lhs << (-rhs) - } - } -} - -pub fn exec_boolean_not(reg: &OwnedValue) -> OwnedValue { - match reg { - OwnedValue::Null => OwnedValue::Null, - OwnedValue::Integer(i) => OwnedValue::Integer((*i == 0) as i64), - OwnedValue::Float(f) => OwnedValue::Integer((*f == 0.0) as i64), - OwnedValue::Text(text) => exec_boolean_not(&cast_text_to_numeric(text.as_str())), - _ => todo!(), - } -} -pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - match (lhs, rhs) { - (OwnedValue::Text(lhs_text), OwnedValue::Text(rhs_text)) => { - OwnedValue::build_text(&(lhs_text.as_str().to_string() + rhs_text.as_str())) - } - (OwnedValue::Text(lhs_text), OwnedValue::Integer(rhs_int)) => { - OwnedValue::build_text(&(lhs_text.as_str().to_string() + &rhs_int.to_string())) - } - (OwnedValue::Text(lhs_text), OwnedValue::Float(rhs_float)) => { - OwnedValue::build_text(&(lhs_text.as_str().to_string() + &rhs_float.to_string())) - } - (OwnedValue::Integer(lhs_int), OwnedValue::Text(rhs_text)) => { - OwnedValue::build_text(&(lhs_int.to_string() + rhs_text.as_str())) - } - (OwnedValue::Integer(lhs_int), OwnedValue::Integer(rhs_int)) => { - OwnedValue::build_text(&(lhs_int.to_string() + &rhs_int.to_string())) - } - (OwnedValue::Integer(lhs_int), OwnedValue::Float(rhs_float)) => { - OwnedValue::build_text(&(lhs_int.to_string() + &rhs_float.to_string())) - } - (OwnedValue::Float(lhs_float), OwnedValue::Text(rhs_text)) => { - OwnedValue::build_text(&(lhs_float.to_string() + rhs_text.as_str())) - } - (OwnedValue::Float(lhs_float), OwnedValue::Integer(rhs_int)) => { - OwnedValue::build_text(&(lhs_float.to_string() + &rhs_int.to_string())) - } - (OwnedValue::Float(lhs_float), OwnedValue::Float(rhs_float)) => { - OwnedValue::build_text(&(lhs_float.to_string() + &rhs_float.to_string())) - } - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Blob(_), _) | (_, OwnedValue::Blob(_)) => { - todo!("TODO: Handle Blob conversion to String") - } - } -} - -pub fn exec_and(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - match (lhs, rhs) { - (_, OwnedValue::Integer(0)) - | (OwnedValue::Integer(0), _) - | (_, OwnedValue::Float(0.0)) - | (OwnedValue::Float(0.0), _) => OwnedValue::Integer(0), - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_and( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_and(&cast_text_to_numeric(text.as_str()), other) - } - _ => OwnedValue::Integer(1), - } -} - -pub fn exec_or(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { - match (lhs, rhs) { - (OwnedValue::Null, OwnedValue::Null) - | (OwnedValue::Null, OwnedValue::Float(0.0)) - | (OwnedValue::Float(0.0), OwnedValue::Null) - | (OwnedValue::Null, OwnedValue::Integer(0)) - | (OwnedValue::Integer(0), OwnedValue::Null) => OwnedValue::Null, - (OwnedValue::Float(0.0), OwnedValue::Integer(0)) - | (OwnedValue::Integer(0), OwnedValue::Float(0.0)) - | (OwnedValue::Float(0.0), OwnedValue::Float(0.0)) - | (OwnedValue::Integer(0), OwnedValue::Integer(0)) => OwnedValue::Integer(0), - (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_or( - &cast_text_to_numeric(lhs.as_str()), - &cast_text_to_numeric(rhs.as_str()), - ), - (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_or(&cast_text_to_numeric(text.as_str()), other) - } - _ => OwnedValue::Integer(1), - } -} - impl Insn { pub fn to_function(&self) -> InsnFunction { match self { @@ -1419,471 +980,19 @@ impl Insn { } } -#[cfg(test)] -mod tests { - use crate::{ - types::{OwnedValue, Text}, - vdbe::insn::exec_or, - }; - - use super::exec_add; - - #[test] - fn test_exec_add() { - let inputs = vec![ - (OwnedValue::Integer(3), OwnedValue::Integer(1)), - (OwnedValue::Float(3.0), OwnedValue::Float(1.0)), - (OwnedValue::Float(3.0), OwnedValue::Integer(1)), - (OwnedValue::Integer(3), OwnedValue::Float(1.0)), - (OwnedValue::Null, OwnedValue::Null), - (OwnedValue::Null, OwnedValue::Integer(1)), - (OwnedValue::Null, OwnedValue::Float(1.0)), - (OwnedValue::Null, OwnedValue::Text(Text::from_str("2"))), - (OwnedValue::Integer(1), OwnedValue::Null), - (OwnedValue::Float(1.0), OwnedValue::Null), - (OwnedValue::Text(Text::from_str("1")), OwnedValue::Null), - ( - OwnedValue::Text(Text::from_str("1")), - OwnedValue::Text(Text::from_str("3")), - ), - ( - OwnedValue::Text(Text::from_str("1.0")), - OwnedValue::Text(Text::from_str("3.0")), - ), - ( - OwnedValue::Text(Text::from_str("1.0")), - OwnedValue::Float(3.0), - ), - ( - OwnedValue::Text(Text::from_str("1.0")), - OwnedValue::Integer(3), - ), - ( - OwnedValue::Float(1.0), - OwnedValue::Text(Text::from_str("3.0")), - ), - ( - OwnedValue::Integer(1), - OwnedValue::Text(Text::from_str("3")), - ), - ]; - - let outputs = [ - OwnedValue::Integer(4), - OwnedValue::Float(4.0), - OwnedValue::Float(4.0), - OwnedValue::Float(4.0), - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Integer(4), - OwnedValue::Float(4.0), - OwnedValue::Float(4.0), - OwnedValue::Float(4.0), - OwnedValue::Float(4.0), - OwnedValue::Float(4.0), - ]; - - assert_eq!( - inputs.len(), - outputs.len(), - "Inputs and Outputs should have same size" - ); - for (i, (lhs, rhs)) in inputs.iter().enumerate() { - assert_eq!( - exec_add(lhs, rhs), - outputs[i], - "Wrong ADD for lhs: {}, rhs: {}", - lhs, - rhs - ); - } - } - - use super::exec_subtract; - - #[test] - fn test_exec_subtract() { - let inputs = vec![ - (OwnedValue::Integer(3), OwnedValue::Integer(1)), - (OwnedValue::Float(3.0), OwnedValue::Float(1.0)), - (OwnedValue::Float(3.0), OwnedValue::Integer(1)), - (OwnedValue::Integer(3), OwnedValue::Float(1.0)), - (OwnedValue::Null, OwnedValue::Null), - (OwnedValue::Null, OwnedValue::Integer(1)), - (OwnedValue::Null, OwnedValue::Float(1.0)), - (OwnedValue::Null, OwnedValue::Text(Text::from_str("1"))), - (OwnedValue::Integer(1), OwnedValue::Null), - (OwnedValue::Float(1.0), OwnedValue::Null), - (OwnedValue::Text(Text::from_str("4")), OwnedValue::Null), - ( - OwnedValue::Text(Text::from_str("1")), - OwnedValue::Text(Text::from_str("3")), - ), - ( - OwnedValue::Text(Text::from_str("1.0")), - OwnedValue::Text(Text::from_str("3.0")), - ), - ( - OwnedValue::Text(Text::from_str("1.0")), - OwnedValue::Float(3.0), - ), - ( - OwnedValue::Text(Text::from_str("1.0")), - OwnedValue::Integer(3), - ), - ( - OwnedValue::Float(1.0), - OwnedValue::Text(Text::from_str("3.0")), - ), - ( - OwnedValue::Integer(1), - OwnedValue::Text(Text::from_str("3")), - ), - ]; - - let outputs = [ - OwnedValue::Integer(2), - OwnedValue::Float(2.0), - OwnedValue::Float(2.0), - OwnedValue::Float(2.0), - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Integer(-2), - OwnedValue::Float(-2.0), - OwnedValue::Float(-2.0), - OwnedValue::Float(-2.0), - OwnedValue::Float(-2.0), - OwnedValue::Float(-2.0), - ]; - - assert_eq!( - inputs.len(), - outputs.len(), - "Inputs and Outputs should have same size" - ); - for (i, (lhs, rhs)) in inputs.iter().enumerate() { - assert_eq!( - exec_subtract(lhs, rhs), - outputs[i], - "Wrong subtract for lhs: {}, rhs: {}", - lhs, - rhs - ); - } - } - use super::exec_multiply; - - #[test] - fn test_exec_multiply() { - let inputs = vec![ - (OwnedValue::Integer(3), OwnedValue::Integer(2)), - (OwnedValue::Float(3.0), OwnedValue::Float(2.0)), - (OwnedValue::Float(3.0), OwnedValue::Integer(2)), - (OwnedValue::Integer(3), OwnedValue::Float(2.0)), - (OwnedValue::Null, OwnedValue::Null), - (OwnedValue::Null, OwnedValue::Integer(1)), - (OwnedValue::Null, OwnedValue::Float(1.0)), - (OwnedValue::Null, OwnedValue::Text(Text::from_str("1"))), - (OwnedValue::Integer(1), OwnedValue::Null), - (OwnedValue::Float(1.0), OwnedValue::Null), - (OwnedValue::Text(Text::from_str("4")), OwnedValue::Null), - ( - OwnedValue::Text(Text::from_str("2")), - OwnedValue::Text(Text::from_str("3")), - ), - ( - OwnedValue::Text(Text::from_str("2.0")), - OwnedValue::Text(Text::from_str("3.0")), - ), - ( - OwnedValue::Text(Text::from_str("2.0")), - OwnedValue::Float(3.0), - ), - ( - OwnedValue::Text(Text::from_str("2.0")), - OwnedValue::Integer(3), - ), - ( - OwnedValue::Float(2.0), - OwnedValue::Text(Text::from_str("3.0")), - ), - ( - OwnedValue::Integer(2), - OwnedValue::Text(Text::from_str("3.0")), - ), - ]; - - let outputs = [ - OwnedValue::Integer(6), - OwnedValue::Float(6.0), - OwnedValue::Float(6.0), - OwnedValue::Float(6.0), - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Integer(6), - OwnedValue::Float(6.0), - OwnedValue::Float(6.0), - OwnedValue::Float(6.0), - OwnedValue::Float(6.0), - OwnedValue::Float(6.0), - ]; - - assert_eq!( - inputs.len(), - outputs.len(), - "Inputs and Outputs should have same size" - ); - for (i, (lhs, rhs)) in inputs.iter().enumerate() { - assert_eq!( - exec_multiply(lhs, rhs), - outputs[i], - "Wrong multiply for lhs: {}, rhs: {}", - lhs, - rhs - ); - } - } - use super::exec_divide; - - #[test] - fn test_exec_divide() { - let inputs = vec![ - (OwnedValue::Integer(1), OwnedValue::Integer(0)), - (OwnedValue::Float(1.0), OwnedValue::Float(0.0)), - (OwnedValue::Integer(i64::MIN), OwnedValue::Integer(-1)), - (OwnedValue::Float(6.0), OwnedValue::Float(2.0)), - (OwnedValue::Float(6.0), OwnedValue::Integer(2)), - (OwnedValue::Integer(6), OwnedValue::Integer(2)), - (OwnedValue::Null, OwnedValue::Integer(2)), - (OwnedValue::Integer(2), OwnedValue::Null), - (OwnedValue::Null, OwnedValue::Null), - ( - OwnedValue::Text(Text::from_str("6")), - OwnedValue::Text(Text::from_str("2")), - ), - ( - OwnedValue::Text(Text::from_str("6")), - OwnedValue::Integer(2), - ), - ]; - - let outputs = [ - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Float(9.223372036854776e18), - OwnedValue::Float(3.0), - OwnedValue::Float(3.0), - OwnedValue::Float(3.0), - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Float(3.0), - OwnedValue::Float(3.0), - ]; - - assert_eq!( - inputs.len(), - outputs.len(), - "Inputs and Outputs should have same size" - ); - for (i, (lhs, rhs)) in inputs.iter().enumerate() { - assert_eq!( - exec_divide(lhs, rhs), - outputs[i], - "Wrong divide for lhs: {}, rhs: {}", - lhs, - rhs - ); - } - } - - use super::exec_remainder; - #[test] - fn test_exec_remainder() { - let inputs = vec![ - (OwnedValue::Null, OwnedValue::Null), - (OwnedValue::Null, OwnedValue::Float(1.0)), - (OwnedValue::Null, OwnedValue::Integer(1)), - (OwnedValue::Null, OwnedValue::Text(Text::from_str("1"))), - (OwnedValue::Float(1.0), OwnedValue::Null), - (OwnedValue::Integer(1), OwnedValue::Null), - (OwnedValue::Integer(12), OwnedValue::Integer(0)), - (OwnedValue::Float(12.0), OwnedValue::Float(0.0)), - (OwnedValue::Float(12.0), OwnedValue::Integer(0)), - (OwnedValue::Integer(12), OwnedValue::Float(0.0)), - (OwnedValue::Integer(i64::MIN), OwnedValue::Integer(-1)), - (OwnedValue::Integer(12), OwnedValue::Integer(3)), - (OwnedValue::Float(12.0), OwnedValue::Float(3.0)), - (OwnedValue::Float(12.0), OwnedValue::Integer(3)), - (OwnedValue::Integer(12), OwnedValue::Float(3.0)), - (OwnedValue::Integer(12), OwnedValue::Integer(-3)), - (OwnedValue::Float(12.0), OwnedValue::Float(-3.0)), - (OwnedValue::Float(12.0), OwnedValue::Integer(-3)), - (OwnedValue::Integer(12), OwnedValue::Float(-3.0)), - ( - OwnedValue::Text(Text::from_str("12.0")), - OwnedValue::Text(Text::from_str("3.0")), - ), - ( - OwnedValue::Text(Text::from_str("12.0")), - OwnedValue::Float(3.0), - ), - ( - OwnedValue::Float(12.0), - OwnedValue::Text(Text::from_str("3.0")), - ), - ]; - let outputs = vec![ - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Float(0.0), - OwnedValue::Integer(0), - OwnedValue::Float(0.0), - OwnedValue::Float(0.0), - OwnedValue::Float(0.0), - OwnedValue::Integer(0), - OwnedValue::Float(0.0), - OwnedValue::Float(0.0), - OwnedValue::Float(0.0), - OwnedValue::Float(0.0), - OwnedValue::Float(0.0), - OwnedValue::Float(0.0), - ]; - - assert_eq!( - inputs.len(), - outputs.len(), - "Inputs and Outputs should have same size" - ); - - for (i, (lhs, rhs)) in inputs.iter().enumerate() { - assert_eq!( - exec_remainder(lhs, rhs), - outputs[i], - "Wrong remainder for lhs: {}, rhs: {}", - lhs, - rhs - ); - } - } - - use super::exec_and; - - #[test] - fn test_exec_and() { - let inputs = vec![ - (OwnedValue::Integer(0), OwnedValue::Null), - (OwnedValue::Null, OwnedValue::Integer(1)), - (OwnedValue::Null, OwnedValue::Null), - (OwnedValue::Float(0.0), OwnedValue::Null), - (OwnedValue::Integer(1), OwnedValue::Float(2.2)), - ( - OwnedValue::Integer(0), - OwnedValue::Text(Text::from_str("string")), - ), - ( - OwnedValue::Integer(0), - OwnedValue::Text(Text::from_str("1")), - ), - ( - OwnedValue::Integer(1), - OwnedValue::Text(Text::from_str("1")), - ), - ]; - let outputs = [ - OwnedValue::Integer(0), - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Integer(0), - OwnedValue::Integer(1), - OwnedValue::Integer(0), - OwnedValue::Integer(0), - OwnedValue::Integer(1), - ]; - - assert_eq!( - inputs.len(), - outputs.len(), - "Inputs and Outputs should have same size" - ); - for (i, (lhs, rhs)) in inputs.iter().enumerate() { - assert_eq!( - exec_and(lhs, rhs), - outputs[i], - "Wrong AND for lhs: {}, rhs: {}", - lhs, - rhs - ); - } - } - - #[test] - fn test_exec_or() { - let inputs = vec![ - (OwnedValue::Integer(0), OwnedValue::Null), - (OwnedValue::Null, OwnedValue::Integer(1)), - (OwnedValue::Null, OwnedValue::Null), - (OwnedValue::Float(0.0), OwnedValue::Null), - (OwnedValue::Integer(1), OwnedValue::Float(2.2)), - (OwnedValue::Float(0.0), OwnedValue::Integer(0)), - ( - OwnedValue::Integer(0), - OwnedValue::Text(Text::from_str("string")), - ), - ( - OwnedValue::Integer(0), - OwnedValue::Text(Text::from_str("1")), - ), - (OwnedValue::Integer(0), OwnedValue::Text(Text::from_str(""))), - ]; - let outputs = [ - OwnedValue::Null, - OwnedValue::Integer(1), - OwnedValue::Null, - OwnedValue::Null, - OwnedValue::Integer(1), - OwnedValue::Integer(0), - OwnedValue::Integer(0), - OwnedValue::Integer(1), - OwnedValue::Integer(0), - ]; - - assert_eq!( - inputs.len(), - outputs.len(), - "Inputs and Outputs should have same size" - ); - for (i, (lhs, rhs)) in inputs.iter().enumerate() { - assert_eq!( - exec_or(lhs, rhs), - outputs[i], - "Wrong OR for lhs: {}, rhs: {}", - lhs, - rhs - ); - } - } +// TODO: Add remaining cookies. +#[derive(Description, Debug, Clone, Copy)] +pub enum Cookie { + /// The schema cookie. + SchemaVersion = 1, + /// The schema format number. Supported schema formats are 1, 2, 3, and 4. + DatabaseFormat = 2, + /// Default page cache size. + DefaultPageCacheSize = 3, + /// The page number of the largest root b-tree page when in auto-vacuum or incremental-vacuum modes, or zero otherwise. + LargestRootPageNumber = 4, + /// The database text encoding. A value of 1 means UTF-8. A value of 2 means UTF-16le. A value of 3 means UTF-16be. + DatabaseTextEncoding = 5, + /// The "user version" as read and set by the user_version pragma. + UserVersion = 6, } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index cf6918304..daad191b4 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -24,19 +24,18 @@ pub mod insn; pub mod likeop; pub mod sorter; -use crate::error::LimboError; -use crate::fast_lock::SpinLock; -use crate::function::{AggFunc, FuncCtx}; - -use crate::storage::sqlite3_ondisk::DatabaseHeader; -use crate::storage::{btree::BTreeCursor, pager::Pager}; -use crate::translate::plan::{ResultSetColumn, TableReference}; -use crate::types::{ - AggContext, Cursor, CursorResult, ImmutableRecord, OwnedValue, SeekKey, SeekOp, +use crate::{ + error::LimboError, + fast_lock::SpinLock, + function::{AggFunc, FuncCtx}, +}; + +use crate::{ + storage::{btree::BTreeCursor, pager::Pager, sqlite3_ondisk::DatabaseHeader}, + translate::plan::{ResultSetColumn, TableReference}, + types::{AggContext, Cursor, CursorResult, ImmutableRecord, OwnedValue, SeekKey, SeekOp}, + vdbe::{builder::CursorType, insn::Insn}, }; -use crate::util::cast_text_to_numeric; -use crate::vdbe::builder::CursorType; -use crate::vdbe::insn::Insn; use crate::CheckpointStatus; @@ -45,16 +44,20 @@ use crate::json::JsonCacheCell; use crate::{Connection, MvStore, Result, TransactionState}; use execute::{InsnFunction, InsnFunctionStepResult}; -use rand::distributions::{Distribution, Uniform}; -use rand::Rng; +use rand::{ + distributions::{Distribution, Uniform}, + Rng, +}; use regex::Regex; -use std::cell::{Cell, RefCell}; -use std::collections::HashMap; -use std::ffi::c_void; -use std::num::NonZero; -use std::ops::Deref; -use std::rc::{Rc, Weak}; -use std::sync::Arc; +use std::{ + cell::{Cell, RefCell}, + collections::HashMap, + ffi::c_void, + num::NonZero, + ops::Deref, + rc::{Rc, Weak}, + sync::Arc, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] /// Represents a target for a jump instruction. diff --git a/stress/Cargo.toml b/stress/Cargo.toml index 59c4e8256..9c0097d45 100644 --- a/stress/Cargo.toml +++ b/stress/Cargo.toml @@ -20,3 +20,8 @@ clap = { version = "4.5", features = ["derive"] } limbo = { path = "../bindings/rust" } serde_json = "1.0.139" tokio = { version = "1.29.1", features = ["full"] } +anarchist-readable-name-generator-lib = "0.1.0" +hex = "0.4" +tracing = "0.1.41" +tracing-appender = "0.2.3" +tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } diff --git a/stress/main.rs b/stress/main.rs index c62714e63..e8b61b459 100644 --- a/stress/main.rs +++ b/stress/main.rs @@ -1,14 +1,396 @@ mod opts; +use anarchist_readable_name_generator_lib::readable_name_custom; +use antithesis_sdk::random::{get_random, AntithesisRng}; use antithesis_sdk::*; use clap::Parser; -use limbo::{Builder, Value}; +use core::panic; +use hex; +use limbo::Builder; use opts::Opts; use serde_json::json; +use std::collections::HashSet; +use std::fs::File; +use std::io::{Read, Write}; use std::sync::Arc; +use tracing_appender::non_blocking::WorkerGuard; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::EnvFilter; + +pub struct Plan { + pub ddl_statements: Vec, + pub queries_per_thread: Vec>, + pub nr_iterations: usize, + pub nr_threads: usize, +} + +/// Represents a column in a SQLite table +#[derive(Debug, Clone)] +pub struct Column { + pub name: String, + pub data_type: DataType, + pub constraints: Vec, +} + +/// Represents SQLite data types +#[derive(Debug, Clone)] +pub enum DataType { + Integer, + Real, + Text, + Blob, + Numeric, +} + +/// Represents column constraints +#[derive(Debug, Clone, PartialEq)] +pub enum Constraint { + PrimaryKey, + NotNull, + Unique, +} + +/// Represents a table in a SQLite schema +#[derive(Debug, Clone)] +pub struct Table { + pub name: String, + pub columns: Vec, +} + +/// Represents a complete SQLite schema +#[derive(Debug, Clone)] +pub struct ArbitrarySchema { + pub tables: Vec, +} + +// Helper functions for generating random data +fn generate_random_identifier() -> String { + readable_name_custom("_", AntithesisRng).replace('-', "_") +} + +fn generate_random_data_type() -> DataType { + match get_random() % 5 { + 0 => DataType::Integer, + 1 => DataType::Real, + 2 => DataType::Text, + 3 => DataType::Blob, + _ => DataType::Numeric, + } +} + +fn generate_random_constraint() -> Constraint { + match get_random() % 2 { + 0 => Constraint::NotNull, + _ => Constraint::Unique, + } +} + +fn generate_random_column() -> Column { + let name = generate_random_identifier(); + let data_type = generate_random_data_type(); + + let constraint_count = (get_random() % 3) as usize; + let mut constraints = Vec::with_capacity(constraint_count); + + for _ in 0..constraint_count { + constraints.push(generate_random_constraint()); + } + + Column { + name, + data_type, + constraints, + } +} + +fn generate_random_table() -> Table { + let name = generate_random_identifier(); + let column_count = (get_random() % 10 + 1) as usize; + let mut columns = Vec::with_capacity(column_count); + let mut column_names = HashSet::new(); + + // First, generate all columns without primary keys + for _ in 0..column_count { + let mut column = generate_random_column(); + + // Ensure column names are unique within the table + while column_names.contains(&column.name) { + column.name = generate_random_identifier(); + } + + column_names.insert(column.name.clone()); + columns.push(column); + } + + // Then, randomly select one column to be the primary key + let pk_index = (get_random() % column_count as u64) as usize; + columns[pk_index].constraints.push(Constraint::PrimaryKey); + + Table { name, columns } +} + +pub fn gen_schema() -> ArbitrarySchema { + let table_count = (get_random() % 10 + 1) as usize; + let mut tables = Vec::with_capacity(table_count); + let mut table_names = HashSet::new(); + + for _ in 0..table_count { + let mut table = generate_random_table(); + + // Ensure table names are unique + while table_names.contains(&table.name) { + table.name = generate_random_identifier(); + } + + table_names.insert(table.name.clone()); + tables.push(table); + } + + ArbitrarySchema { tables } +} + +impl ArbitrarySchema { + /// Convert the schema to a vector of SQL DDL statements + pub fn to_sql(&self) -> Vec { + self.tables + .iter() + .map(|table| { + let columns = table + .columns + .iter() + .map(|col| { + let mut col_def = + format!(" {} {}", col.name, data_type_to_sql(&col.data_type)); + for constraint in &col.constraints { + col_def.push(' '); + col_def.push_str(&constraint_to_sql(constraint)); + } + col_def + }) + .collect::>() + .join(","); + + format!("CREATE TABLE {} ({});", table.name, columns) + }) + .collect() + } +} + +fn data_type_to_sql(data_type: &DataType) -> &'static str { + match data_type { + DataType::Integer => "INTEGER", + DataType::Real => "REAL", + DataType::Text => "TEXT", + DataType::Blob => "BLOB", + DataType::Numeric => "NUMERIC", + } +} + +fn constraint_to_sql(constraint: &Constraint) -> String { + match constraint { + Constraint::PrimaryKey => "PRIMARY KEY".to_string(), + Constraint::NotNull => "NOT NULL".to_string(), + Constraint::Unique => "UNIQUE".to_string(), + } +} + +/// Generate a random value for a given data type +fn generate_random_value(data_type: &DataType) -> String { + match data_type { + DataType::Integer => (get_random() % 1000).to_string(), + DataType::Real => format!("{:.2}", (get_random() % 1000) as f64 / 100.0), + DataType::Text => format!("'{}'", generate_random_identifier()), + DataType::Blob => format!("x'{}'", hex::encode(generate_random_identifier())), + DataType::Numeric => (get_random() % 1000).to_string(), + } +} + +/// Generate a random INSERT statement for a table +fn generate_insert(table: &Table) -> String { + let columns = table + .columns + .iter() + .map(|col| col.name.clone()) + .collect::>() + .join(", "); + + let values = table + .columns + .iter() + .map(|col| generate_random_value(&col.data_type)) + .collect::>() + .join(", "); + + format!( + "INSERT INTO {} ({}) VALUES ({});", + table.name, columns, values + ) +} + +/// Generate a random UPDATE statement for a table +fn generate_update(table: &Table) -> String { + // Find the primary key column + let pk_column = table + .columns + .iter() + .find(|col| col.constraints.contains(&Constraint::PrimaryKey)) + .expect("Table should have a primary key"); + + // Get all non-primary key columns + let non_pk_columns: Vec<_> = table + .columns + .iter() + .filter(|col| col.name != pk_column.name) + .collect(); + + // If we have no non-PK columns, just update the primary key itself + let set_clause = if non_pk_columns.is_empty() { + format!( + "{} = {}", + pk_column.name, + generate_random_value(&pk_column.data_type) + ) + } else { + non_pk_columns + .iter() + .map(|col| format!("{} = {}", col.name, generate_random_value(&col.data_type))) + .collect::>() + .join(", ") + }; + + let where_clause = format!( + "{} = {}", + pk_column.name, + generate_random_value(&pk_column.data_type) + ); + + format!( + "UPDATE {} SET {} WHERE {};", + table.name, set_clause, where_clause + ) +} + +/// Generate a random DELETE statement for a table +fn generate_delete(table: &Table) -> String { + // Find the primary key column + let pk_column = table + .columns + .iter() + .find(|col| col.constraints.contains(&Constraint::PrimaryKey)) + .expect("Table should have a primary key"); + + let where_clause = format!( + "{} = {}", + pk_column.name, + generate_random_value(&pk_column.data_type) + ); + + format!("DELETE FROM {} WHERE {};", table.name, where_clause) +} + +/// Generate a random SQL statement for a schema +fn generate_random_statement(schema: &ArbitrarySchema) -> String { + let table = &schema.tables[get_random() as usize % schema.tables.len()]; + match get_random() % 3 { + 0 => generate_insert(table), + 1 => generate_update(table), + _ => generate_delete(table), + } +} + +fn generate_plan(opts: &Opts) -> Result> { + let schema = gen_schema(); + // Write DDL statements to log file + let mut log_file = File::create(&opts.log_file)?; + let ddl_statements = schema.to_sql(); + let mut plan = Plan { + ddl_statements: vec![], + queries_per_thread: vec![], + nr_iterations: opts.nr_iterations, + nr_threads: opts.nr_threads, + }; + writeln!(log_file, "{}", opts.nr_threads)?; + writeln!(log_file, "{}", opts.nr_iterations)?; + writeln!(log_file, "{}", ddl_statements.len())?; + for stmt in &ddl_statements { + writeln!(log_file, "{}", stmt)?; + } + plan.ddl_statements = ddl_statements; + for _ in 0..opts.nr_threads { + let mut queries = vec![]; + for _ in 0..opts.nr_iterations { + let sql = generate_random_statement(&schema); + writeln!(log_file, "{}", sql)?; + queries.push(sql); + } + plan.queries_per_thread.push(queries); + } + Ok(plan) +} + +fn read_plan_from_log_file(opts: &Opts) -> Result> { + let mut file = File::open(&opts.log_file)?; + let mut buf = String::new(); + let mut plan = Plan { + ddl_statements: vec![], + queries_per_thread: vec![], + nr_iterations: 0, + nr_threads: 0, + }; + file.read_to_string(&mut buf).unwrap(); + let mut lines = buf.lines(); + plan.nr_threads = lines.next().expect("missing threads").parse().unwrap(); + plan.nr_iterations = lines + .next() + .expect("missing nr_iterations") + .parse() + .unwrap(); + let nr_ddl = lines + .next() + .expect("number of ddl statements") + .parse() + .unwrap(); + for _ in 0..nr_ddl { + plan.ddl_statements + .push(lines.next().expect("expected ddl statement").to_string()); + } + for _ in 0..plan.nr_threads { + let mut queries = vec![]; + for _ in 0..plan.nr_iterations { + queries.push( + lines + .next() + .expect("missing query for thread {}") + .to_string(), + ); + } + plan.queries_per_thread.push(queries); + } + Ok(plan) +} + +pub fn init_tracing() -> Result { + let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stderr()); + if let Err(e) = tracing_subscriber::registry() + .with( + tracing_subscriber::fmt::layer() + .with_writer(non_blocking) + .with_ansi(false) + .with_line_number(true) + .with_thread_ids(true), + ) + .with(EnvFilter::from_default_env()) + .try_init() + { + println!("Unable to setup tracing appender: {:?}", e); + } + Ok(guard) +} #[tokio::main] -async fn main() { +async fn main() -> Result<(), Box> { + let _g = init_tracing()?; let (num_nodes, main_id) = (1, "n-001"); let startup_data = json!({ "num_nodes": num_nodes, @@ -17,28 +399,69 @@ async fn main() { lifecycle::setup_complete(&startup_data); antithesis_init(); - let opts = Opts::parse(); - let mut handles = Vec::new(); + let mut opts = Opts::parse(); + + let plan = if opts.load_log { + read_plan_from_log_file(&mut opts)? + } else { + generate_plan(&opts)? + }; + + let mut handles = Vec::with_capacity(opts.nr_threads); + let plan = Arc::new(plan); + + for thread in 0..opts.nr_threads { + let db = Arc::new(Builder::new_local(&opts.db_file).build().await?); + let plan = plan.clone(); + let conn = db.connect()?; + + // Apply each DDL statement individually + for stmt in &plan.ddl_statements { + println!("executing ddl {}", stmt); + if let Err(e) = conn.execute(stmt, ()).await { + match e { + limbo::Error::SqlExecutionFailure(e) => { + if e.contains("Corrupt database") { + panic!("Error creating table: {}", e); + } else { + println!("Error creating table: {}", e); + } + } + _ => panic!("Error creating table: {}", e), + } + } + } - for _ in 0..opts.nr_threads { - // TODO: share the database between threads - let db = Arc::new(Builder::new_local(":memory:").build().await.unwrap()); let nr_iterations = opts.nr_iterations; let db = db.clone(); - let handle = tokio::spawn(async move { - let conn = db.connect().unwrap(); - for _ in 0..nr_iterations { - let mut rows = conn.query("select 1", ()).await.unwrap(); - let row = rows.next().await.unwrap().unwrap(); - let value = row.get_value(0).unwrap(); - assert_always!(matches!(value, Value::Integer(1)), "value is incorrect"); + let handle = tokio::spawn(async move { + let conn = db.connect()?; + for query_index in 0..nr_iterations { + let sql = &plan.queries_per_thread[thread][query_index]; + println!("executing: {}", sql); + if let Err(e) = conn.execute(&sql, ()).await { + match e { + limbo::Error::SqlExecutionFailure(e) => { + if e.contains("Corrupt database") { + panic!("Error executing query: {}", e); + } else { + println!("Error executing query: {}", e); + } + } + _ => panic!("Error executing query: {}", e), + } + } } + Ok::<_, Box>(()) }); handles.push(handle); } + for handle in handles { - handle.await.unwrap(); + handle.await??; } - println!("Done."); + println!("Done. SQL statements written to {}", opts.log_file); + println!("Database file: {}", opts.db_file); + Ok(()) } diff --git a/stress/opts.rs b/stress/opts.rs index 392d79448..a8cbb5b2a 100644 --- a/stress/opts.rs +++ b/stress/opts.rs @@ -4,13 +4,43 @@ use clap::{command, Parser}; #[command(name = "limbo_stress")] #[command(author, version, about, long_about = None)] pub struct Opts { + /// Number of threads to run #[clap(short = 't', long, help = "the number of threads", default_value_t = 8)] pub nr_threads: usize, + + /// Number of iterations per thread #[clap( short = 'i', long, help = "the number of iterations", - default_value_t = 1000 + default_value_t = 100000 )] pub nr_iterations: usize, + + /// Log file for SQL statements + #[clap( + short = 'l', + long, + help = "log file for SQL statements", + default_value = "limbostress.log" + )] + pub log_file: String, + + /// Load log file instead of creating a new one + #[clap( + short = 'L', + long = "load-log", + help = "load log file instead of creating a new one", + default_value_t = false + )] + pub load_log: bool, + + /// Database file + #[clap( + short = 'd', + long, + help = "database file", + default_value = "limbostress.db" + )] + pub db_file: String, } diff --git a/testing/cli_tests/memory.py b/testing/cli_tests/memory.py index e96df3475..da98bcc1d 100755 --- a/testing/cli_tests/memory.py +++ b/testing/cli_tests/memory.py @@ -2,8 +2,6 @@ import os from test_limbo_cli import TestLimboShell - -sqlite_exec = "./target/debug/limbo" sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ") diff --git a/testing/scalar-functions-datetime.test b/testing/scalar-functions-datetime.test index fd450dc02..f6441384f 100755 --- a/testing/scalar-functions-datetime.test +++ b/testing/scalar-functions-datetime.test @@ -589,3 +589,70 @@ set FMT [list %S.%3f %C %y %b %B %h %a %A %D %x %v %.f %.3f %.6f %.9f %3f %6f %9 foreach i $FMT { do_execsql_test strftime-invalid-$i "SELECT strftime('$i','2025-01-23T13:14:30.567');" {} } + + +# Tests for the TIMEDIFF function + +do_execsql_test timediff-basic-positive { + SELECT timediff('14:30:45', '12:00:00'); +} {"+0000-00-00 02:30:45.000"} + +do_execsql_test timediff-basic-negative { + SELECT timediff('12:00:00', '14:30:45'); +} {"-0000-00-00 02:30:45.000"} + +do_execsql_test timediff-with-milliseconds-positive { + SELECT timediff('12:00:01.300', '12:00:00.500'); +} {"+0000-00-00 00:00:00.800"} + +do_execsql_test timediff-same-time { + SELECT timediff('12:00:00', '12:00:00'); +} {"+0000-00-00 00:00:00.000"} + +do_execsql_test timediff-across-dates { + SELECT timediff('2023-05-11 01:15:00', '2023-05-10 23:30:00'); +} {"+0000-00-00 01:45:00.000"} + +do_execsql_test timediff-across-dates-negative { + SELECT timediff('2023-05-10 23:30:00', '2023-05-11 01:15:00'); +} {"-0000-00-00 01:45:00.000"} + +do_execsql_test timediff-different-formats { + SELECT timediff('2023-05-10T23:30:00', '2023-05-10 14:15:00'); +} {"+0000-00-00 09:15:00.000"} + +do_execsql_test timediff-with-timezone { + SELECT timediff('2023-05-10 23:30:00+02:00', '2023-05-10 18:30:00Z'); +} {"+0000-00-00 03:00:00.000"} + +do_execsql_test timediff-large-difference { + SELECT timediff('2023-05-12 10:00:00', '2023-05-10 08:00:00'); +} {"+0000-00-02 02:00:00.000"} + +do_execsql_test timediff-with-seconds-precision { + SELECT timediff('12:30:45.123', '12:30:44.987'); +} {"+0000-00-00 00:00:00.136"} + +do_execsql_test timediff-null-first-arg { + SELECT timediff(NULL, '12:00:00'); +} {{}} + +do_execsql_test timediff-null-second-arg { + SELECT timediff('12:00:00', NULL); +} {{}} + +do_execsql_test timediff-invalid-first-arg { + SELECT timediff('not-a-time', '12:00:00'); +} {{}} + +do_execsql_test timediff-invalid-second-arg { + SELECT timediff('12:00:00', 'not-a-time'); +} {{}} + +do_execsql_test timediff-julian-day { + SELECT timediff(2460000, 2460000.5); +} {"-0000-00-00 12:00:00.000"} + +do_execsql_test timediff-different-time-formats { + SELECT timediff('23:59:59', '00:00:00'); +} {"+0000-00-00 23:59:59.000"} \ No newline at end of file diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index 09e99a8f3..807c4971d 100755 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -211,6 +211,38 @@ do_execsql_test likely-null { select likely(NULL) } {} +do_execsql_test likelihood-string { + SELECT likelihood('limbo', 0.5); +} {limbo} + +do_execsql_test likelihood-string-high-probability { + SELECT likelihood('database', 0.9375); +} {database} + +do_execsql_test likelihood-integer { + SELECT likelihood(100, 0.0625); +} {100} + +do_execsql_test likelihood-integer-probability-1 { + SELECT likelihood(42, 1.0); +} {42} + +do_execsql_test likelihood-decimal { + SELECT likelihood(12.34, 0.5); +} {12.34} + +do_execsql_test likelihood-null { + SELECT likelihood(NULL, 0.5); +} {} + +do_execsql_test likelihood-blob { + SELECT hex(likelihood(x'01020304', 0.5)); +} {01020304} + +do_execsql_test likelihood-zero-probability { + SELECT likelihood(999, 0.0); +} {999} + do_execsql_test unhex-str-ab { SELECT unhex('6162'); } {ab} diff --git a/testing/testing b/testing/testing deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/integration/fuzz/mod.rs b/tests/integration/fuzz/mod.rs index 5df3b49b4..73263e5f1 100644 --- a/tests/integration/fuzz/mod.rs +++ b/tests/integration/fuzz/mod.rs @@ -232,12 +232,8 @@ mod tests { const COMPARISONS: [&str; 5] = ["=", "<", "<=", ">", ">="]; - const ORDER_BY: [Option<&str>; 4] = [ - None, - Some("ORDER BY x"), - Some("ORDER BY x DESC"), - Some("ORDER BY x ASC"), - ]; + const ORDER_BY: [Option<&str>; 3] = [None, Some("ORDER BY x DESC"), Some("ORDER BY x ASC")]; + const SECONDARY_ORDER_BY: [Option<&str>; 3] = [None, Some(", y DESC"), Some(", y ASC")]; let print_dump_on_fail = |insert: &str, seed: u64| { let comment = format!("-- seed: {}; dump for manual debugging:", seed); @@ -252,56 +248,101 @@ mod tests { for comp in COMPARISONS.iter() { for order_by in ORDER_BY.iter() { - for max in 0..=3000 { - // see comment below about ordering and the '=' comparison operator; omitting LIMIT for that reason - // we mainly have LIMIT here for performance reasons but for = we want to get all the rows to ensure - // correctness in the = case - let limit = if *comp == "=" { "" } else { "LIMIT 5" }; + // make it more likely that the full 2-column index is utilized for seeking + let iter_count_per_permutation = if *comp == "=" { 2000 } else { 500 }; + println!( + "fuzzing {} iterations with comp: {:?}, order_by: {:?}", + iter_count_per_permutation, comp, order_by + ); + for _ in 0..iter_count_per_permutation { + let first_col_val = rng.random_range(0..=3000); + let mut limit = "LIMIT 5"; + let mut second_idx_col_cond = "".to_string(); + let mut second_idx_col_comp = "".to_string(); + + // somtetimes include the second index column in the where clause. + // make it more probable when first column has '=' constraint since those queries are usually faster to run + let second_col_prob = if *comp == "=" { 0.7 } else { 0.02 }; + if rng.random_bool(second_col_prob) { + let second_idx_col = rng.random_range(0..3000); + + second_idx_col_comp = + COMPARISONS[rng.random_range(0..COMPARISONS.len())].to_string(); + second_idx_col_cond = + format!(" AND y {} {}", second_idx_col_comp, second_idx_col); + } + + // if the first constraint is =, then half the time, use the second index column in the ORDER BY too + let mut secondary_order_by = None; + let use_secondary_order_by = order_by.is_some() + && *comp == "=" + && second_idx_col_comp != "" + && rng.random_bool(0.5); + let full_order_by = if use_secondary_order_by { + secondary_order_by = + SECONDARY_ORDER_BY[rng.random_range(0..SECONDARY_ORDER_BY.len())]; + if let Some(secondary) = secondary_order_by { + format!("{}{}", order_by.unwrap_or(""), secondary,) + } else { + order_by.unwrap_or("").to_string() + } + } else { + order_by.unwrap_or("").to_string() + }; + + // There are certain cases where SQLite does not bother iterating in reverse order despite the ORDER BY. + // These cases include e.g. + // SELECT * FROM t WHERE x = 3 ORDER BY x DESC + // SELECT * FROM t WHERE x = 3 and y < 100 ORDER BY x DESC + // + // The common thread being that the ORDER BY column is also constrained by an equality predicate, meaning + // that it doesn't semantically matter what the ordering is. + // + // We do not currently replicate this "lazy" behavior, so in these cases we want the full result set and ensure + // that if the result is not exactly equal, then the ordering must be the exact reverse. + let allow_reverse_ordering = { + if *comp != "=" { + false + } else if secondary_order_by.is_some() { + second_idx_col_comp == "=" + } else { + true + } + }; + if allow_reverse_ordering { + // see comment above about ordering and the '=' comparison operator; omitting LIMIT for that reason + // we mainly have LIMIT here for performance reasons but for = we want to get all the rows to ensure + // correctness in the = case + limit = ""; + } let query = format!( - "SELECT * FROM t WHERE x {} {} {} {}", - comp, - max, - order_by.unwrap_or(""), - limit + // e.g. SELECT * FROM t WHERE x = 1 AND y > 2 ORDER BY x DESC LIMIT 5 + "SELECT * FROM t WHERE x {} {} {} {} {}", + comp, first_col_val, second_idx_col_cond, full_order_by, limit, ); - log::trace!("query: {}", query); + log::debug!("query: {}", query); let limbo = limbo_exec_rows(&db, &limbo_conn, &query); let sqlite = sqlite_exec_rows(&sqlite_conn, &query); let is_equal = limbo == sqlite; if !is_equal { - // if the condition is = and the same rows are present but in different order, then we accept that - // e.g. sqlite doesn't bother iterating in reverse order if "WHERE X = 3 ORDER BY X DESC", but we currently do. - if *comp == "=" { + if allow_reverse_ordering { let limbo_row_count = limbo.len(); let sqlite_row_count = sqlite.len(); if limbo_row_count == sqlite_row_count { - for limbo_row in limbo.iter() { - if !sqlite.contains(limbo_row) { - // save insert to file and print the filename for debugging - let error_msg = format!("row not found in sqlite: query: {}, limbo: {:?}, sqlite: {:?}, seed: {}", query, limbo, sqlite, seed); - print_dump_on_fail(&insert, seed); - panic!("{}", error_msg); - } - } - for sqlite_row in sqlite.iter() { - if !limbo.contains(sqlite_row) { - let error_msg = format!("row not found in limbo: query: {}, limbo: {:?}, sqlite: {:?}, seed: {}", query, limbo, sqlite, seed); - print_dump_on_fail(&insert, seed); - panic!("{}", error_msg); - } - } - continue; + let limbo_rev = limbo.iter().cloned().rev().collect::>(); + assert_eq!(limbo_rev, sqlite, "query: {}, limbo: {:?}, sqlite: {:?}, seed: {}, allow_reverse_ordering: {}", query, limbo, sqlite, seed, allow_reverse_ordering); } else { print_dump_on_fail(&insert, seed); - let error_msg = format!("row count mismatch (limbo: {}, sqlite: {}): query: {}, limbo: {:?}, sqlite: {:?}, seed: {}", limbo_row_count, sqlite_row_count, query, limbo, sqlite, seed); + let error_msg = format!("row count mismatch (limbo row count: {}, sqlite row count: {}): query: {}, limbo: {:?}, sqlite: {:?}, seed: {}, allow_reverse_ordering: {}", limbo_row_count, sqlite_row_count, query, limbo, sqlite, seed, allow_reverse_ordering); panic!("{}", error_msg); } + } else { + print_dump_on_fail(&insert, seed); + panic!( + "query: {}, limbo row count: {}, limbo: {:?}, sqlite row count: {}, sqlite: {:?}, seed: {}, allow_reverse_ordering: {}", + query, limbo.len(), limbo, sqlite.len(), sqlite, seed, allow_reverse_ordering + ); } - print_dump_on_fail(&insert, seed); - panic!( - "query: {}, limbo: {:?}, sqlite: {:?}, seed: {}", - query, limbo, sqlite, seed - ); } } }