From dd58be3b602dfa78c3e95d77f8332f9b15911a8a Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Wed, 5 Feb 2025 23:08:20 -0300 Subject: [PATCH 01/16] Add basic structure for crypto extension --- Cargo.lock | 63 ++++++++++++++++++++++++++ Cargo.toml | 1 + core/Cargo.toml | 4 +- core/ext/mod.rs | 4 ++ extensions/core/src/types.rs | 32 +++++++++++++- extensions/crypto/Cargo.toml | 21 +++++++++ extensions/crypto/src/crypto.rs | 55 +++++++++++++++++++++++ extensions/crypto/src/lib.rs | 78 +++++++++++++++++++++++++++++++++ testing/extensions.py | 1 - 9 files changed, 256 insertions(+), 3 deletions(-) create mode 100644 extensions/crypto/Cargo.toml create mode 100644 extensions/crypto/src/crypto.rs create mode 100644 extensions/crypto/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index cb6dda8d8..700451334 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -140,6 +140,12 @@ version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" @@ -204,6 +210,19 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +[[package]] +name = "blake3" +version = "1.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -455,6 +474,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1590,6 +1615,7 @@ dependencies = [ "julian_day_converter", "libc", "libloading", + "limbo_crypto", "limbo_ext", "limbo_macros", "limbo_percentile", @@ -1621,6 +1647,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "limbo_crypto" +version = "0.0.14" +dependencies = [ + "blake3", + "limbo_ext", + "mimalloc", + "ring", +] + [[package]] name = "limbo_ext" version = "0.0.14" @@ -2538,6 +2574,21 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.15", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rstest" version = "0.18.2" @@ -2761,6 +2812,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "sqlite3-parser" version = "0.13.0" @@ -3130,6 +3187,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.4" diff --git a/Cargo.toml b/Cargo.toml index 85b8c46d7..c6b4137fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "extensions/percentile", "extensions/vector", "extensions/time", + "extensions/crypto", ] exclude = ["perf/latency/limbo"] diff --git a/core/Cargo.toml b/core/Cargo.toml index f2bba5f59..12f827ea2 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -14,7 +14,7 @@ name = "limbo_core" path = "lib.rs" [features] -default = ["fs", "json", "uuid", "vector", "io_uring", "time"] +default = ["fs", "json", "uuid", "vector", "io_uring", "time", "crypto"] fs = [] json = [ "dep:jsonb", @@ -27,6 +27,7 @@ io_uring = ["dep:io-uring", "rustix/io_uring"] percentile = ["limbo_percentile/static"] regexp = ["limbo_regexp/static"] time = ["limbo_time/static"] +crypto = ["limbo_crypto/static"] [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.6.1", optional = true } @@ -67,6 +68,7 @@ limbo_vector = { path = "../extensions/vector", optional = true, features = ["st limbo_regexp = { path = "../extensions/regexp", optional = true, features = ["static"] } limbo_percentile = { path = "../extensions/percentile", optional = true, features = ["static"] } limbo_time = { path = "../extensions/time", optional = true, features = ["static"] } +limbo_crypto = { path = "../extensions/crypto", optional = true, features = ["static"] } miette = "7.4.0" strum = "0.26" parking_lot = "0.12.3" diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 8a9212556..06ca4d7fb 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -96,6 +96,10 @@ impl Database { if unsafe { !limbo_time::register_extension_static(&ext_api).is_ok() } { return Err("Failed to register time extension".to_string()); } + #[cfg(feature = "crypto")] + if unsafe { !limbo_crypto::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register crypto extension".to_string()); + } Ok(()) } } diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 74fa670ad..63c9a3b54 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::{fmt::Display, mem}; /// Error type is of type ExtError which can be /// either a user defined error or an error code @@ -204,6 +204,13 @@ impl Blob { pub fn new(data: *const u8, size: u64) -> Self { Self { data, size } } + + pub fn as_bytes(&self) -> &[u8] { + if self.data.is_null() { + return &[]; + } + unsafe { std::slice::from_raw_parts(self.data, self.size as usize) } + } } impl Value { @@ -303,6 +310,29 @@ impl Value { } } + // Return ValueData as raw bytes + pub fn as_bytes(&self) -> Vec { + let mut bytes = vec![]; + + unsafe { + match self.value_type { + ValueType::Integer => bytes.extend_from_slice(&self.value.int.to_le_bytes()), + ValueType::Float => bytes.extend_from_slice(&self.value.float.to_le_bytes()), + ValueType::Text => { + let text = self.value.text.as_ref().expect("Invalid text pointer"); + bytes.extend_from_slice(text.as_str().as_bytes()); + } + ValueType::Blob => { + let blob = self.value.blob.as_ref().expect("Invalid blob pointer"); + bytes.extend_from_slice(blob.as_bytes()); + } + ValueType::Error | ValueType::Null => {} + } + } + + bytes + } + /// Creates a new integer Value from an i64 pub fn from_integer(i: i64) -> Self { Self { diff --git a/extensions/crypto/Cargo.toml b/extensions/crypto/Cargo.toml new file mode 100644 index 000000000..9cd714156 --- /dev/null +++ b/extensions/crypto/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "limbo_crypto" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +static= [ "limbo_ext/static" ] + +[dependencies] +blake3 = "1.5.5" +limbo_ext = { path = "../core", features = ["static"] } +ring = "0.17.8" + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } diff --git a/extensions/crypto/src/crypto.rs b/extensions/crypto/src/crypto.rs new file mode 100644 index 000000000..472b8dac1 --- /dev/null +++ b/extensions/crypto/src/crypto.rs @@ -0,0 +1,55 @@ +use crate::Error; +use blake3::Hasher; +use limbo_ext::{Value, ValueType}; +use ring::digest::{self, digest}; + +pub fn sha256(data: &Value) -> Result, Error> { + match data.value_type() { + ValueType::Error | ValueType::Null => Err(Error::InvalidType), + _ => { + let hash = digest(&digest::SHA256, &data.as_bytes()); + Ok(hash.as_ref().to_vec()) + } + } +} + +pub fn sha512(data: &Value) -> Result, Error> { + match data.value_type() { + ValueType::Error | ValueType::Null => Err(Error::InvalidType), + _ => { + let hash = digest(&digest::SHA512, &data.as_bytes()); + Ok(hash.as_ref().to_vec()) + } + } +} + +pub fn sha384(data: &Value) -> Result, Error> { + match data.value_type() { + ValueType::Error | ValueType::Null => Err(Error::InvalidType), + _ => { + let hash = digest(&digest::SHA384, &data.as_bytes()); + Ok(hash.as_ref().to_vec()) + } + } +} + +pub fn blake3(data: &Value) -> Result, Error> { + match data.value_type() { + ValueType::Error | ValueType::Null => Err(Error::InvalidType), + _ => { + let mut hasher = Hasher::new(); + hasher.update(data.as_bytes().as_ref()); + Ok(hasher.finalize().as_bytes().to_vec()) + } + } +} + +pub fn sha1(data: &Value) -> Result, Error> { + match data.value_type() { + ValueType::Error | ValueType::Null => Err(Error::InvalidType), + _ => { + let hash = digest(&digest::SHA1_FOR_LEGACY_USE_ONLY, &data.as_bytes()); + Ok(hash.as_ref().to_vec()) + } + } +} diff --git a/extensions/crypto/src/lib.rs b/extensions/crypto/src/lib.rs new file mode 100644 index 000000000..49f6d3e9b --- /dev/null +++ b/extensions/crypto/src/lib.rs @@ -0,0 +1,78 @@ +use crypto::{blake3, sha1, sha256, sha384, sha512}; +use limbo_ext::{register_extension, scalar, ResultCode, Value}; + +mod crypto; + +#[derive(Debug)] +enum Error { + InvalidType, +} + +#[scalar(name = "crypto_sha256", alias = "crypto_sha256")] +fn crypto_sha256(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + + let Ok(hash) = sha256(&args[0]) else { + return Value::error(ResultCode::Error); + }; + + Value::from_blob(hash) +} + +#[scalar(name = "crypto_sha512", alias = "crypto_sha512")] +fn crypto_sha512(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + + let Ok(hash) = sha512(&args[0]) else { + return Value::error(ResultCode::Error); + }; + + Value::from_blob(hash) +} + +#[scalar(name = "crypto_sha384", alias = "crypto_sha384")] +fn crypto_sha384(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + + let Ok(hash) = sha384(&args[0]) else { + return Value::error(ResultCode::Error); + }; + + Value::from_blob(hash) +} + +#[scalar(name = "crypto_blake3", alias = "crypto_blake3")] +fn crypto_blake3(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + + let Ok(hash) = blake3(&args[0]) else { + return Value::error(ResultCode::Error); + }; + + Value::from_blob(hash) +} + +#[scalar(name = "crypto_sha1", alias = "crypto_sha1")] +fn crypto_sha1(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + + let Ok(hash) = sha1(&args[0]) else { + return Value::error(ResultCode::Error); + }; + + Value::from_blob(hash) +} + +register_extension! { + scalars: { crypto_sha256, crypto_sha512, crypto_sha384, crypto_blake3, crypto_sha1 }, +} diff --git a/testing/extensions.py b/testing/extensions.py index 74755a012..a1094e865 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -255,7 +255,6 @@ def test_aggregates(pipe): pipe, "SELECT percentile_disc(value, 0.55) from test;", validate_percentile_disc ) - def main(): pipe = init_limbo() try: From 846d5ed4141666bcd4a950c9d34e3ada4f4aafa1 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Thu, 6 Feb 2025 00:04:36 -0300 Subject: [PATCH 02/16] add md5 and encode to extension --- extensions/crypto/Cargo.toml | 4 ++++ extensions/crypto/src/crypto.rs | 36 +++++++++++++++++++++++++++++++++ extensions/crypto/src/lib.rs | 33 +++++++++++++++++++++++++++--- 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/extensions/crypto/Cargo.toml b/extensions/crypto/Cargo.toml index 9cd714156..84bd10efc 100644 --- a/extensions/crypto/Cargo.toml +++ b/extensions/crypto/Cargo.toml @@ -13,9 +13,13 @@ crate-type = ["cdylib", "lib"] static= [ "limbo_ext/static" ] [dependencies] +ascii85 = "0.2.1" blake3 = "1.5.5" +data-encoding = "2.7.0" limbo_ext = { path = "../core", features = ["static"] } +md5 = "0.7.0" ring = "0.17.8" +urlencoding = "2.1.3" [target.'cfg(not(target_family = "wasm"))'.dependencies] mimalloc = { version = "*", default-features = false } diff --git a/extensions/crypto/src/crypto.rs b/extensions/crypto/src/crypto.rs index 472b8dac1..654307bd4 100644 --- a/extensions/crypto/src/crypto.rs +++ b/extensions/crypto/src/crypto.rs @@ -1,5 +1,6 @@ use crate::Error; use blake3::Hasher; +use data_encoding::{BASE32, BASE64, HEXLOWER}; use limbo_ext::{Value, ValueType}; use ring::digest::{self, digest}; @@ -53,3 +54,38 @@ pub fn sha1(data: &Value) -> Result, Error> { } } } + +pub fn md5(data: &Value) -> Result, Error> { + match data.value_type() { + ValueType::Error | ValueType::Null => Err(Error::InvalidType), + _ => { + let digest = md5::compute::<&Vec>(data.as_bytes().as_ref()); + + Ok(digest.as_ref().to_vec()) + } + } +} + +pub fn encode(data: &Value, format: &Value) -> Result { + match (data.value_type(), format.value_type()) { + (ValueType::Error, _) | (ValueType::Null, _) => Err(Error::InvalidType), + (_, ValueType::Text) => match format.to_text().unwrap().to_lowercase().as_str() { + "base32" => Ok(Value::from_text(BASE32.encode(data.as_bytes().as_ref()))), + "base64" => Ok(Value::from_text(BASE64.encode(data.as_bytes().as_ref()))), + "hex" => Ok(Value::from_text(HEXLOWER.encode(data.as_bytes().as_ref()))), + "base85" => { + let result = ascii85::encode(data.as_bytes().as_ref()) + .replace("<~", "") + .replace("~>", ""); + Ok(Value::from_text(result)) + } + "url" => { + let data = data.as_bytes(); + let url = urlencoding::encode_binary(&data); + Ok(Value::from_text(url.to_string())) + } + _ => Err(Error::UnknownOperation), + }, + _ => Err(Error::InvalidType), + } +} diff --git a/extensions/crypto/src/lib.rs b/extensions/crypto/src/lib.rs index 49f6d3e9b..09fe28c38 100644 --- a/extensions/crypto/src/lib.rs +++ b/extensions/crypto/src/lib.rs @@ -1,4 +1,4 @@ -use crypto::{blake3, sha1, sha256, sha384, sha512}; +use crypto::{blake3, encode, md5, sha1, sha256, sha384, sha512}; use limbo_ext::{register_extension, scalar, ResultCode, Value}; mod crypto; @@ -6,6 +6,7 @@ mod crypto; #[derive(Debug)] enum Error { InvalidType, + UnknownOperation, } #[scalar(name = "crypto_sha256", alias = "crypto_sha256")] @@ -73,6 +74,32 @@ fn crypto_sha1(args: &[Value]) -> Value { Value::from_blob(hash) } -register_extension! { - scalars: { crypto_sha256, crypto_sha512, crypto_sha384, crypto_blake3, crypto_sha1 }, +#[scalar(name = "crypto_md5", alias = "crypto_md5")] +fn crypto_md5(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::error(ResultCode::Error); + } + + let Ok(hash) = md5(&args[0]) else { + return Value::error(ResultCode::Error); + }; + + Value::from_blob(hash) +} + +#[scalar(name = "crypto_encode", alias = "crypto_encode")] +fn crypto_encode(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::Error); + } + + let Ok(payload) = encode(&args[0], &args[1]) else { + return Value::error(ResultCode::Error); + }; + + payload +} + +register_extension! { + scalars: { crypto_sha256, crypto_sha512, crypto_sha384, crypto_blake3, crypto_sha1, crypto_md5, crypto_encode }, } From 05057a04ac9ae0afb3269a580961a20d0cab4567 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Thu, 6 Feb 2025 01:42:47 -0300 Subject: [PATCH 03/16] completes crypto extension It aims to be compatible with https://github.com/nalgeon/sqlean/blob/main/docs/crypto.md --- Cargo.lock | 21 ++++ core/Cargo.toml | 2 +- extensions/core/src/types.rs | 2 +- extensions/crypto/Cargo.toml | 1 - extensions/crypto/src/crypto.rs | 136 +++++++++++++++++++++++-- extensions/crypto/src/lib.rs | 21 +++- testing/extensions.py | 171 +++++++++++++++++++++++++++++++- 7 files changed, 339 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 700451334..ee60a825c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -642,6 +642,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "data-encoding" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e60eed09d8c01d3cee5b7d30acb059b76614c918fa0f992e0dd6eeb10daad6f" + [[package]] name = "debugid" version = "0.8.0" @@ -1652,9 +1658,12 @@ name = "limbo_crypto" version = "0.0.14" dependencies = [ "blake3", + "data-encoding", "limbo_ext", + "md5", "mimalloc", "ring", + "urlencoding", ] [[package]] @@ -1787,6 +1796,12 @@ version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.7.4" @@ -3204,6 +3219,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf16_iter" version = "1.0.5" diff --git a/core/Cargo.toml b/core/Cargo.toml index 12f827ea2..97406179f 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -14,7 +14,7 @@ name = "limbo_core" path = "lib.rs" [features] -default = ["fs", "json", "uuid", "vector", "io_uring", "time", "crypto"] +default = ["fs", "json", "uuid", "vector", "io_uring", "time"] fs = [] json = [ "dep:jsonb", diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 63c9a3b54..464e07bfd 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, mem}; +use std::fmt::Display; /// Error type is of type ExtError which can be /// either a user defined error or an error code diff --git a/extensions/crypto/Cargo.toml b/extensions/crypto/Cargo.toml index 84bd10efc..7aa8cc5e6 100644 --- a/extensions/crypto/Cargo.toml +++ b/extensions/crypto/Cargo.toml @@ -13,7 +13,6 @@ crate-type = ["cdylib", "lib"] static= [ "limbo_ext/static" ] [dependencies] -ascii85 = "0.2.1" blake3 = "1.5.5" data-encoding = "2.7.0" limbo_ext = { path = "../core", features = ["static"] } diff --git a/extensions/crypto/src/crypto.rs b/extensions/crypto/src/crypto.rs index 654307bd4..ddebbd0a6 100644 --- a/extensions/crypto/src/crypto.rs +++ b/extensions/crypto/src/crypto.rs @@ -3,6 +3,7 @@ use blake3::Hasher; use data_encoding::{BASE32, BASE64, HEXLOWER}; use limbo_ext::{Value, ValueType}; use ring::digest::{self, digest}; +use std::{borrow::Cow, error::Error as StdError}; pub fn sha256(data: &Value) -> Result, Error> { match data.value_type() { @@ -73,12 +74,7 @@ pub fn encode(data: &Value, format: &Value) -> Result { "base32" => Ok(Value::from_text(BASE32.encode(data.as_bytes().as_ref()))), "base64" => Ok(Value::from_text(BASE64.encode(data.as_bytes().as_ref()))), "hex" => Ok(Value::from_text(HEXLOWER.encode(data.as_bytes().as_ref()))), - "base85" => { - let result = ascii85::encode(data.as_bytes().as_ref()) - .replace("<~", "") - .replace("~>", ""); - Ok(Value::from_text(result)) - } + "base85" => Ok(Value::from_text(encode_ascii85(data.as_bytes().as_ref()))), "url" => { let data = data.as_bytes(); let url = urlencoding::encode_binary(&data); @@ -89,3 +85,131 @@ pub fn encode(data: &Value, format: &Value) -> Result { _ => Err(Error::InvalidType), } } + +pub fn decode(data: &Value, format: &Value) -> Result { + match (data.value_type(), format.value_type()) { + (ValueType::Error, _) | (ValueType::Null, _) => Err(Error::InvalidType), + (ValueType::Text, ValueType::Text) => { + let format_str = format.to_text().ok_or(Error::InvalidType)?.to_lowercase(); + let input_text = data.to_text().ok_or(Error::InvalidType)?; + + match format_str.as_str() { + "base32" => { + let payload = BASE32 + .decode(input_text.as_bytes()) + .map_err(|_| Error::DecodeFailed)?; + Ok(Value::from_text( + String::from_utf8(payload).map_err(|_| Error::InvalidUtf8)?, + )) + } + "base64" => { + let payload = BASE64 + .decode(input_text.as_bytes()) + .map_err(|_| Error::DecodeFailed)?; + Ok(Value::from_text( + String::from_utf8(payload).map_err(|_| Error::InvalidUtf8)?, + )) + } + "hex" => { + let payload = HEXLOWER + .decode(input_text.to_lowercase().as_bytes()) + .map_err(|_| Error::DecodeFailed)?; + Ok(Value::from_text( + String::from_utf8(payload).map_err(|_| Error::InvalidUtf8)?, + )) + } + "base85" => { + let decoded = decode_ascii85(&input_text).map_err(|_| Error::DecodeFailed)?; + + Ok(Value::from_text( + String::from_utf8(decoded).map_err(|_| Error::InvalidUtf8)?, + )) + } + "url" => { + let decoded = urlencoding::decode_binary(input_text.as_bytes()); + Ok(Value::from_text( + String::from_utf8(decoded.to_vec()).map_err(|_| Error::InvalidUtf8)?, + )) + } + _ => Err(Error::UnknownOperation), + } + } + _ => Err(Error::InvalidType), + } +} + +// Ascii85 functions to avoid +1 dependency and to remove '~>' '<~' + +const TABLE: [u32; 5] = [85 * 85 * 85 * 85, 85 * 85 * 85, 85 * 85, 85, 1]; + +fn decode_ascii85(input: &str) -> Result, Box> { + let mut result = Vec::with_capacity(4 * (input.len() / 5 + 16)); + + let mut counter = 0; + let mut chunk = 0; + + for digit in input.trim().bytes().filter(|c| !c.is_ascii_whitespace()) { + if digit == b'z' { + if counter == 0 { + result.extend_from_slice(&[0, 0, 0, 0]); + } else { + return Err("Missaligned z in input".into()); + } + } + + if digit < 33 || digit > 117 { + return Err("Input char is out of range for Ascii85".into()); + } + + decode_digit(digit, &mut counter, &mut chunk, &mut result); + } + + let mut to_remove = 0; + + while counter != 0 { + decode_digit(b'u', &mut counter, &mut chunk, &mut result); + to_remove += 1; + } + + result.drain((result.len() - to_remove)..result.len()); + + Ok(result) +} + +fn decode_digit(digit: u8, counter: &mut usize, chunk: &mut u32, result: &mut Vec) { + let byte = digit - 33; + + *chunk += byte as u32 * TABLE[*counter]; + + if *counter == 4 { + result.extend_from_slice(&chunk.to_be_bytes()); + *chunk = 0; + *counter = 0; + } else { + *counter += 1; + } +} + +fn encode_ascii85(input: &[u8]) -> String { + let mut result = String::with_capacity(5 * (input.len() / 4 + 16)); + + for chunk in input.chunks(4) { + let (chunk, count) = if chunk.len() == 4 { + (Cow::from(chunk), 5) + } else { + let mut new_chunk = Vec::new(); + new_chunk.resize_with(4, || 0); + new_chunk[..chunk.len()].copy_from_slice(chunk); + (Cow::from(new_chunk), 5 - (4 - chunk.len())) + }; + + let number = u32::from_be_bytes(chunk.as_ref().try_into().expect("Internal Error")); + + for i in 0..count { + let digit = (((number / TABLE[i]) % 85) + 33) as u8; + result.push(digit as char); + } + } + + result +} diff --git a/extensions/crypto/src/lib.rs b/extensions/crypto/src/lib.rs index 09fe28c38..604f313e0 100644 --- a/extensions/crypto/src/lib.rs +++ b/extensions/crypto/src/lib.rs @@ -1,4 +1,4 @@ -use crypto::{blake3, encode, md5, sha1, sha256, sha384, sha512}; +use crypto::{blake3, decode, encode, md5, sha1, sha256, sha384, sha512}; use limbo_ext::{register_extension, scalar, ResultCode, Value}; mod crypto; @@ -7,6 +7,8 @@ mod crypto; enum Error { InvalidType, UnknownOperation, + DecodeFailed, + InvalidUtf8, } #[scalar(name = "crypto_sha256", alias = "crypto_sha256")] @@ -100,6 +102,19 @@ fn crypto_encode(args: &[Value]) -> Value { payload } -register_extension! { - scalars: { crypto_sha256, crypto_sha512, crypto_sha384, crypto_blake3, crypto_sha1, crypto_md5, crypto_encode }, +#[scalar(name = "crypto_decode", alias = "crypto_decode")] +fn crypto_decode(args: &[Value]) -> Value { + if args.len() != 2 { + return Value::error(ResultCode::Error); + } + + let Ok(payload) = decode(&args[0], &args[1]) else { + return Value::error(ResultCode::Error); + }; + + payload +} + +register_extension! { + scalars: { crypto_sha256, crypto_sha512, crypto_sha384, crypto_blake3, crypto_sha1, crypto_md5, crypto_encode, crypto_decode }, } diff --git a/testing/extensions.py b/testing/extensions.py index a1094e865..d4a0a69c0 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -110,7 +110,6 @@ def validate_blob(result): # and assert they are valid hex digits return int(result, 16) is not None - def validate_string_uuid(result): return len(result) == 36 and result.count("-") == 4 @@ -130,7 +129,6 @@ def assert_now_unixtime(result): def assert_specific_time(result): return result == "1736720789" - def test_uuid(pipe): specific_time = "01945ca0-3189-76c0-9a8f-caf310fc8b8e" # these are built into the binary, so we just test they work @@ -207,7 +205,6 @@ def validate_percentile2(res): def validate_percentile_disc(res): return res == "40.0" - def test_aggregates(pipe): extension_path = "./target/debug/liblimbo_percentile.so" # assert no function before extension loads @@ -255,12 +252,180 @@ def test_aggregates(pipe): pipe, "SELECT percentile_disc(value, 0.55) from test;", validate_percentile_disc ) +# Hashes +def validate_blake3(a): + return a == "6437b3ac38465133ffb63b75273a8db548c558465d79db03fd359c6cd5bd9d85" + +def validate_md5(a): + return a == "900150983cd24fb0d6963f7d28e17f72" + +def validate_sha1(a): + return a == "a9993e364706816aba3e25717850c26c9cd0d89d" + +def validate_sha256(a): + return a == "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + +def validate_sha384(a): + return a == "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7" + +def validate_sha512(a): + return a == "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f" + +# Encoders and decoders +def validate_url_encode(a): + return a == f"%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29" + +def validate_url_decode(a): + return a == "/hello?text=(ಠ_ಠ)" + +def validate_hex_encode(a): + return a == "68656c6c6f" + +def validate_hex_decode(a): + return a == "hello" + +def validate_base85_encode(a): + return a == "BOu!rDZ" + +def validate_base85_decode(a): + return a == "hello" + +def validate_base32_encode(a): + return a == "NBSWY3DP" + +def validate_base32_decode(a): + return a == "hello" + +def validate_base64_encode(a): + return a == "aGVsbG8=" + +def validate_base64_decode(a): + return a == "hello" + +def test_crypto(pipe): + extension_path = "./target/debug/liblimbo_crypto.so" + # assert no function before extension loads + run_test( + pipe, + "SELECT crypto_blake('a');", + returns_error, + "crypto_blake3 returns null when ext not loaded", + ) + run_test( + pipe, + f".load {extension_path}", + returns_null, + "load extension command works properly", + ) + # Hashing and Decode + run_test( + pipe, + "SELECT crypto_encode(crypto_blake3('abc'), 'hex');", + validate_blake3, + "blake3 should encrypt correctly" + ) + run_test( + pipe, + "SELECT crypto_encode(crypto_md5('abc'), 'hex');", + validate_md5, + "md5 should encrypt correctly" + ) + run_test( + pipe, + "SELECT crypto_encode(crypto_sha1('abc'), 'hex');", + validate_sha1, + "sha1 should encrypt correctly" + ) + run_test( + pipe, + "SELECT crypto_encode(crypto_sha256('abc'), 'hex');", + validate_sha256, + "sha256 should encrypt correctly" + ) + run_test( + pipe, + "SELECT crypto_encode(crypto_sha384('abc'), 'hex');", + validate_sha384, + "sha384 should encrypt correctly" + ) + run_test( + pipe, + "SELECT crypto_encode(crypto_sha512('abc'), 'hex');", + validate_sha512, + "sha512 should encrypt correctly" + ) + + # Encoding and Decoding + run_test( + pipe, + "SELECT crypto_encode('hello', 'base32');", + validate_base32_encode, + "base32 should encode correctly" + ) + run_test( + pipe, + "SELECT crypto_decode('NBSWY3DP', 'base32');", + validate_base32_decode, + "base32 should decode correctly" + ) + run_test( + pipe, + "SELECT crypto_encode('hello', 'base64');", + validate_base64_encode, + "base64 should encode correctly" + ) + run_test( + pipe, + "SELECT crypto_decode('aGVsbG8=', 'base64');", + validate_base64_decode, + "base64 should decode correctly" + ) + run_test( + pipe, + "SELECT crypto_encode('hello', 'base85');", + validate_base85_encode, + "base85 should encode correctly" + ) + run_test( + pipe, + "SELECT crypto_decode('BOu!rDZ', 'base85');", + validate_base85_decode, + "base85 should decode correctly" + ) + + run_test( + pipe, + "SELECT crypto_encode('hello', 'hex');", + validate_hex_encode, + "hex should encode correctly" + ) + run_test( + pipe, + "SELECT crypto_decode('68656c6c6f', 'hex');", + validate_hex_decode, + "hex should decode correctly" + ) + + run_test( + pipe, + "SELECT crypto_encode('/hello?text=(ಠ_ಠ)', 'url');", + validate_url_encode, + "url should encode correctly" + ) + run_test( + pipe, + f"SELECT crypto_decode('%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29', 'url');", + validate_url_decode, + "url should decode correctly" + ) + def main(): pipe = init_limbo() try: test_regexp(pipe) test_uuid(pipe) test_aggregates(pipe) + test_crypto(pipe) except Exception as e: print(f"Test FAILED: {e}") pipe.terminate() From f9828e0e6f68ade34d2b676aa40d3f0626ab043c Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Thu, 6 Feb 2025 12:26:25 +0200 Subject: [PATCH 04/16] core: Parse UTF-8 strings lazily --- core/json/json_operations.rs | 12 ++-- core/json/mod.rs | 93 ++++++++++++---------------- core/storage/sqlite3_ondisk.rs | 11 +++- core/types.rs | 38 +++++++----- core/vdbe/datetime.rs | 8 +-- core/vdbe/insn.rs | 94 ++++++++++++++-------------- core/vdbe/likeop.rs | 2 +- core/vdbe/mod.rs | 109 +++++++++++++++++---------------- core/vdbe/printf.rs | 4 +- 9 files changed, 188 insertions(+), 183 deletions(-) diff --git a/core/json/json_operations.rs b/core/json/json_operations.rs index 6b378b09f..a52b45760 100644 --- a/core/json/json_operations.rs +++ b/core/json/json_operations.rs @@ -164,7 +164,7 @@ pub fn json_remove(args: &[OwnedValue]) -> crate::Result { .iter() .map(|path| { if let OwnedValue::Text(path) = path { - json_path(&path.value) + json_path(path.as_str()) } else { crate::bail_constraint_error!("bad JSON path: {:?}", path.to_string()) } @@ -514,7 +514,7 @@ mod tests { let result = json_remove(&args).unwrap(); match result { - OwnedValue::Text(t) => assert_eq!(t.value.as_str(), "[1,2,4,5]"), + OwnedValue::Text(t) => assert_eq!(t.as_str(), "[1,2,4,5]"), _ => panic!("Expected Text value"), } } @@ -529,7 +529,7 @@ mod tests { let result = json_remove(&args).unwrap(); match result { - OwnedValue::Text(t) => assert_eq!(t.value.as_str(), r#"{"b":2}"#), + OwnedValue::Text(t) => assert_eq!(t.as_str(), r#"{"b":2}"#), _ => panic!("Expected Text value"), } } @@ -543,7 +543,7 @@ mod tests { let result = json_remove(&args).unwrap(); match result { - OwnedValue::Text(t) => assert_eq!(t.value.as_str(), r#"{"a":{"b":{"d":2}}}"#), + OwnedValue::Text(t) => assert_eq!(t.as_str(), r#"{"a":{"b":{"d":2}}}"#), _ => panic!("Expected Text value"), } } @@ -557,7 +557,7 @@ mod tests { let result = json_remove(&args).unwrap(); match result { - OwnedValue::Text(t) => assert_eq!(t.value.as_str(), r#"{"a":2,"a":3}"#), + OwnedValue::Text(t) => assert_eq!(t.as_str(), r#"{"a":2,"a":3}"#), _ => panic!("Expected Text value"), } } @@ -584,7 +584,7 @@ mod tests { let result = json_remove(&args).unwrap(); match result { OwnedValue::Text(t) => { - let value = t.value.as_str(); + let value = t.as_str(); assert!(value.contains(r#"[1,3]"#)); assert!(value.contains(r#"{"x":2}"#)); } diff --git a/core/json/mod.rs b/core/json/mod.rs index 57d9ba2b3..31f12ba14 100644 --- a/core/json/mod.rs +++ b/core/json/mod.rs @@ -73,7 +73,7 @@ pub fn get_json(json_value: &OwnedValue, indent: Option<&str>) -> crate::Result< fn get_json_value(json_value: &OwnedValue) -> crate::Result { match json_value { - OwnedValue::Text(ref t) => match from_str::(&t.value) { + OwnedValue::Text(ref t) => match from_str::(t.as_str()) { Ok(json) => Ok(json), Err(_) => { crate::bail_parse_error!("malformed JSON") @@ -104,9 +104,9 @@ pub fn json_array(values: &[OwnedValue]) -> crate::Result { OwnedValue::Blob(_) => crate::bail_constraint_error!("JSON cannot hold BLOB values"), OwnedValue::Text(t) => { if t.subtype == TextSubtype::Json { - s.push_str(&t.value); + s.push_str(t.as_str()); } else { - match to_string(&*t.value) { + match to_string(&t.as_str().to_string()) { Ok(json) => s.push_str(&json), Err(_) => crate::bail_parse_error!("malformed JSON"), } @@ -166,10 +166,12 @@ pub fn json_set(json: &OwnedValue, values: &[OwnedValue]) -> crate::Result Val::String(value.to_string()), + OwnedValue::Text( + t @ Text { + subtype: TextSubtype::Text, + .. + }, + ) => Val::String(t.as_str().to_string()), _ => get_json_value(value)?, }; @@ -323,7 +325,7 @@ fn convert_db_type_to_json(value: &OwnedValue) -> crate::Result { OwnedValue::Text(t) => match t.subtype { // Convert only to json if the subtype is json (if we got it from another json function) TextSubtype::Json => get_json_value(value)?, - TextSubtype::Text => Val::String(t.value.to_string()), + TextSubtype::Text => Val::String(t.as_str().to_string()), }, OwnedValue::Blob(_) => crate::bail_constraint_error!("JSON cannot hold BLOB values"), unsupported_value => crate::bail_constraint_error!( @@ -431,18 +433,21 @@ fn json_extract_single<'a>( fn json_path_from_owned_value(path: &OwnedValue, strict: bool) -> crate::Result> { let json_path = if strict { match path { - OwnedValue::Text(t) => json_path(t.value.as_str())?, + OwnedValue::Text(t) => json_path(t.as_str())?, OwnedValue::Null => return Ok(None), _ => crate::bail_constraint_error!("JSON path error near: {:?}", path.to_string()), } } else { match path { OwnedValue::Text(t) => { - if t.value.starts_with("$") { - json_path(t.value.as_str())? + if t.as_str().starts_with("$") { + json_path(t.as_str())? } else { JsonPath { - elements: vec![PathElement::Root(), PathElement::Key(t.value.to_string())], + elements: vec![ + PathElement::Root(), + PathElement::Key(t.as_str().to_string()), + ], } } } @@ -606,7 +611,7 @@ fn find_or_create_target<'a>(json: &'a mut Val, path: &JsonPath) -> Option crate::Result { match json { - OwnedValue::Text(t) => match from_str::(&t.value) { + OwnedValue::Text(t) => match from_str::(t.as_str()) { Ok(_) => Ok(OwnedValue::Integer(0)), Err(JsonError::Message { location, .. }) => { if let Some(loc) = location { @@ -639,7 +644,7 @@ pub fn json_object(values: &[OwnedValue]) -> crate::Result { .map(|chunk| match chunk { [key, value] => { let key = match key { - OwnedValue::Text(t) => t.value.to_string(), + OwnedValue::Text(t) => t.as_str().to_string(), _ => crate::bail_constraint_error!("labels must be TEXT"), }; let json_val = convert_db_type_to_json(value)?; @@ -656,7 +661,7 @@ pub fn json_object(values: &[OwnedValue]) -> crate::Result { pub fn is_json_valid(json_value: &OwnedValue) -> crate::Result { match json_value { - OwnedValue::Text(ref t) => match from_str::(&t.value) { + OwnedValue::Text(ref t) => match from_str::(t.as_str()) { Ok(_) => Ok(OwnedValue::Integer(1)), Err(_) => Ok(OwnedValue::Integer(0)), }, @@ -679,7 +684,7 @@ mod tests { let input = OwnedValue::build_text(Rc::new("{ key: 'value' }".to_string())); let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { - assert!(result_str.value.contains("\"key\":\"value\"")); + assert!(result_str.as_str().contains("\"key\":\"value\"")); assert_eq!(result_str.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -691,7 +696,7 @@ mod tests { let input = OwnedValue::build_text(Rc::new("{ key: ''value'' }".to_string())); let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { - assert!(result_str.value.contains("\"key\":\"value\"")); + assert!(result_str.as_str().contains("\"key\":\"value\"")); assert_eq!(result_str.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -703,7 +708,7 @@ mod tests { let input = OwnedValue::build_text(Rc::new("{ \"key\": Infinity }".to_string())); let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { - assert!(result_str.value.contains("{\"key\":9e999}")); + assert!(result_str.as_str().contains("{\"key\":9e999}")); assert_eq!(result_str.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -715,7 +720,7 @@ mod tests { let input = OwnedValue::build_text(Rc::new("{ \"key\": -Infinity }".to_string())); let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { - assert!(result_str.value.contains("{\"key\":-9e999}")); + assert!(result_str.as_str().contains("{\"key\":-9e999}")); assert_eq!(result_str.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -727,7 +732,7 @@ mod tests { let input = OwnedValue::build_text(Rc::new("{ \"key\": NaN }".to_string())); let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { - assert!(result_str.value.contains("{\"key\":null}")); + assert!(result_str.as_str().contains("{\"key\":null}")); assert_eq!(result_str.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -749,7 +754,7 @@ mod tests { let input = OwnedValue::build_text(Rc::new("{\"key\":\"value\"}".to_string())); let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { - assert!(result_str.value.contains("\"key\":\"value\"")); + assert!(result_str.as_str().contains("\"key\":\"value\"")); assert_eq!(result_str.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -772,7 +777,7 @@ mod tests { let input = OwnedValue::Blob(Rc::new(binary_json)); let result = get_json(&input, None).unwrap(); if let OwnedValue::Text(result_str) = result { - assert!(result_str.value.contains("\"asd\":\"adf\"")); + assert!(result_str.as_str().contains("\"asd\":\"adf\"")); assert_eq!(result_str.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -809,7 +814,7 @@ mod tests { let result = json_array(&input).unwrap(); if let OwnedValue::Text(res) = result { - assert_eq!(res.value.as_str(), "[\"value1\",\"value2\",1,1.1]"); + assert_eq!(res.as_str(), "[\"value1\",\"value2\",1,1.1]"); assert_eq!(res.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -822,7 +827,7 @@ mod tests { let result = json_array(&input).unwrap(); if let OwnedValue::Text(res) = result { - assert_eq!(res.value.as_str(), "[]"); + assert_eq!(res.as_str(), "[]"); assert_eq!(res.subtype, TextSubtype::Json); } else { panic!("Expected OwnedValue::Text"); @@ -1064,7 +1069,7 @@ mod tests { let OwnedValue::Text(json_text) = result else { panic!("Expected OwnedValue::Text"); }; - assert_eq!(json_text.value.as_str(), r#"{"key":"value"}"#); + assert_eq!(json_text.as_str(), r#"{"key":"value"}"#); } #[test] @@ -1100,7 +1105,7 @@ mod tests { panic!("Expected OwnedValue::Text"); }; assert_eq!( - json_text.value.as_str(), + json_text.as_str(), r#"{"text_key":"text_value","json_key":{"json":"value","number":1},"integer_key":1,"float_key":1.1,"null_key":null}"# ); } @@ -1115,7 +1120,7 @@ mod tests { let OwnedValue::Text(json_text) = result else { panic!("Expected OwnedValue::Text"); }; - assert_eq!(json_text.value.as_str(), r#"{"key":{"json":"value"}}"#); + assert_eq!(json_text.as_str(), r#"{"key":{"json":"value"}}"#); } #[test] @@ -1128,10 +1133,7 @@ mod tests { let OwnedValue::Text(json_text) = result else { panic!("Expected OwnedValue::Text"); }; - assert_eq!( - json_text.value.as_str(), - r#"{"key":"{\"json\":\"value\"}"}"# - ); + assert_eq!(json_text.as_str(), r#"{"key":"{\"json\":\"value\"}"}"#); } #[test] @@ -1149,10 +1151,7 @@ mod tests { let OwnedValue::Text(json_text) = result else { panic!("Expected OwnedValue::Text"); }; - assert_eq!( - json_text.value.as_str(), - r#"{"parent_key":{"key":"value"}}"# - ); + assert_eq!(json_text.as_str(), r#"{"parent_key":{"key":"value"}}"#); } #[test] @@ -1165,7 +1164,7 @@ mod tests { let OwnedValue::Text(json_text) = result else { panic!("Expected OwnedValue::Text"); }; - assert_eq!(json_text.value.as_str(), r#"{"key":"value"}"#); + assert_eq!(json_text.as_str(), r#"{"key":"value"}"#); } #[test] @@ -1176,7 +1175,7 @@ mod tests { let OwnedValue::Text(json_text) = result else { panic!("Expected OwnedValue::Text"); }; - assert_eq!(json_text.value.as_str(), r#"{}"#); + assert_eq!(json_text.as_str(), r#"{}"#); } #[test] @@ -1301,10 +1300,7 @@ mod tests { #[test] fn test_json_path_from_owned_value_root_strict() { - let path = OwnedValue::Text(Text { - value: Rc::new("$".to_string()), - subtype: TextSubtype::Text, - }); + let path = OwnedValue::Text(Text::new(Rc::new("$".to_string()))); let result = json_path_from_owned_value(&path, true); assert!(result.is_ok()); @@ -1321,10 +1317,7 @@ mod tests { #[test] fn test_json_path_from_owned_value_root_non_strict() { - let path = OwnedValue::Text(Text { - value: Rc::new("$".to_string()), - subtype: TextSubtype::Text, - }); + let path = OwnedValue::Text(Text::new(Rc::new("$".to_string()))); let result = json_path_from_owned_value(&path, false); assert!(result.is_ok()); @@ -1341,20 +1334,14 @@ mod tests { #[test] fn test_json_path_from_owned_value_named_strict() { - let path = OwnedValue::Text(Text { - value: Rc::new("field".to_string()), - subtype: TextSubtype::Text, - }); + let path = OwnedValue::Text(Text::new(Rc::new("field".to_string()))); assert!(json_path_from_owned_value(&path, true).is_err()); } #[test] fn test_json_path_from_owned_value_named_non_strict() { - let path = OwnedValue::Text(Text { - value: Rc::new("field".to_string()), - subtype: TextSubtype::Text, - }); + let path = OwnedValue::Text(Text::new(Rc::new("field".to_string()))); let result = json_path_from_owned_value(&path, false); assert!(result.is_ok()); diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index e86d89520..2f2482826 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -46,7 +46,7 @@ use crate::io::{Buffer, Completion, ReadCompletion, SyncCompletion, WriteComplet use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; use crate::storage::pager::Pager; -use crate::types::{OwnedRecord, OwnedValue}; +use crate::types::{OwnedRecord, OwnedValue, Text, TextSubtype}; use crate::{File, Result}; use log::trace; use parking_lot::RwLock; @@ -1059,8 +1059,13 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(OwnedValue, u ); } let bytes = buf[0..n].to_vec(); - let value = unsafe { String::from_utf8_unchecked(bytes) }; - Ok((OwnedValue::build_text(value.into()), n)) + Ok(( + OwnedValue::Text(Text { + value: Rc::new(bytes), + subtype: TextSubtype::Text, + }), + n, + )) } } } diff --git a/core/types.rs b/core/types.rs index 340806e25..cdc47e9f9 100644 --- a/core/types.rs +++ b/core/types.rs @@ -49,7 +49,7 @@ pub enum TextSubtype { #[derive(Debug, Clone, PartialEq)] pub struct Text { - pub value: Rc, + pub value: Rc>, pub subtype: TextSubtype, } @@ -60,17 +60,21 @@ impl Text { pub fn new(value: Rc) -> Self { Self { - value, + value: Rc::new(value.as_bytes().to_vec()), subtype: TextSubtype::Text, } } pub fn json(value: Rc) -> Self { Self { - value, + value: Rc::new(value.as_bytes().to_vec()), subtype: TextSubtype::Json, } } + + pub fn as_str(&self) -> &str { + unsafe { std::str::from_utf8_unchecked(self.value.as_ref()) } + } } #[derive(Debug, Clone, PartialEq)] @@ -103,7 +107,7 @@ impl OwnedValue { pub fn to_text(&self) -> Option<&str> { match self { - OwnedValue::Text(t) => Some(&t.value), + OwnedValue::Text(t) => Some(t.as_str()), _ => None, } } @@ -129,7 +133,7 @@ impl OwnedValue { OwnedValue::Null => Value::Null, OwnedValue::Integer(i) => Value::Integer(*i), OwnedValue::Float(f) => Value::Float(*f), - OwnedValue::Text(s) => Value::Text(&s.value), + OwnedValue::Text(s) => Value::Text(s.as_str()), OwnedValue::Blob(b) => Value::Blob(b), OwnedValue::Agg(a) => match a.as_ref() { AggContext::Avg(acc, _count) => match acc { @@ -187,7 +191,7 @@ impl Display for OwnedValue { Self::Null => write!(f, "NULL"), Self::Integer(i) => write!(f, "{}", i), Self::Float(fl) => write!(f, "{:?}", fl), - Self::Text(s) => write!(f, "{}", s.value), + Self::Text(s) => write!(f, "{}", s.as_str()), Self::Blob(b) => write!(f, "{}", String::from_utf8_lossy(b)), Self::Agg(a) => match a.as_ref() { AggContext::Avg(acc, _count) => write!(f, "{}", acc), @@ -211,7 +215,7 @@ impl OwnedValue { Self::Null => ExtValue::null(), Self::Integer(i) => ExtValue::from_integer(*i), Self::Float(fl) => ExtValue::from_float(*fl), - Self::Text(text) => ExtValue::from_text(text.value.to_string()), + Self::Text(text) => ExtValue::from_text(text.as_str().to_string()), Self::Blob(blob) => ExtValue::from_blob(blob.to_vec()), Self::Agg(_) => todo!(), Self::Record(_) => todo!("Record values not yet supported"), @@ -377,21 +381,21 @@ impl std::ops::Add for OwnedValue { Self::Float(float_left + float_right) } (Self::Text(string_left), Self::Text(string_right)) => Self::build_text(Rc::new( - string_left.value.to_string() + &string_right.value.to_string(), + string_left.as_str().to_string() + &string_right.as_str(), )), (Self::Text(string_left), Self::Integer(int_right)) => Self::build_text(Rc::new( - string_left.value.to_string() + &int_right.to_string(), - )), - (Self::Integer(int_left), Self::Text(string_right)) => Self::build_text(Rc::new( - int_left.to_string() + &string_right.value.to_string(), + string_left.as_str().to_string() + &int_right.to_string(), )), + (Self::Integer(int_left), Self::Text(string_right)) => { + Self::build_text(Rc::new(int_left.to_string() + &string_right.as_str())) + } (Self::Text(string_left), Self::Float(float_right)) => { let string_right = Self::Float(float_right).to_string(); - Self::build_text(Rc::new(string_left.value.to_string() + &string_right)) + Self::build_text(Rc::new(string_left.as_str().to_string() + &string_right)) } (Self::Float(float_left), Self::Text(string_right)) => { let string_left = Self::Float(float_left).to_string(); - Self::build_text(Rc::new(string_left + &string_right.value.to_string())) + Self::build_text(Rc::new(string_left + &string_right.as_str())) } (lhs, Self::Null) => lhs, (Self::Null, rhs) => rhs, @@ -500,7 +504,7 @@ impl<'a> FromValue<'a> for i64 { impl<'a> FromValue<'a> for String { fn from_value(value: &'a OwnedValue) -> Result { match value { - OwnedValue::Text(s) => Ok(s.value.to_string()), + OwnedValue::Text(s) => Ok(s.as_str().to_string()), _ => Err(LimboError::ConversionError("Expected text value".into())), } } @@ -509,7 +513,7 @@ impl<'a> FromValue<'a> for String { impl<'a> FromValue<'a> for &'a str { fn from_value(value: &'a OwnedValue) -> Result { match value { - OwnedValue::Text(s) => Ok(s.value.as_str()), + OwnedValue::Text(s) => Ok(s.as_str()), _ => Err(LimboError::ConversionError("Expected text value".into())), } } @@ -635,7 +639,7 @@ impl OwnedRecord { } } OwnedValue::Float(f) => buf.extend_from_slice(&f.to_be_bytes()), - OwnedValue::Text(t) => buf.extend_from_slice(t.value.as_bytes()), + OwnedValue::Text(t) => buf.extend_from_slice(&t.value), OwnedValue::Blob(b) => buf.extend_from_slice(b), // non serializable OwnedValue::Agg(_) => unreachable!(), diff --git a/core/vdbe/datetime.rs b/core/vdbe/datetime.rs index a4fe2a680..878e722e6 100644 --- a/core/vdbe/datetime.rs +++ b/core/vdbe/datetime.rs @@ -29,7 +29,7 @@ pub fn exec_strftime(values: &[OwnedValue]) -> OwnedValue { } let format_str = match &values[0] { - OwnedValue::Text(text) => text.value.to_string(), + OwnedValue::Text(text) => text.as_str().to_string(), OwnedValue::Integer(num) => num.to_string(), OwnedValue::Float(num) => format!("{:.14}", num), _ => return OwnedValue::Null, @@ -82,7 +82,7 @@ fn modify_dt( if let OwnedValue::Text(ref text_rc) = modifier { // TODO: to prevent double conversion and properly support 'utc'/'localtime', we also // need to keep track of the current timezone and apply it to the modifier. - match apply_modifier(dt, &text_rc.value) { + match apply_modifier(dt, text_rc.as_str()) { Ok(true) => subsec_requested = true, Ok(false) => {} Err(_) => return OwnedValue::build_text(Rc::new(String::new())), @@ -382,7 +382,7 @@ fn get_unixepoch_from_naive_datetime(value: NaiveDateTime) -> String { fn parse_naive_date_time(time_value: &OwnedValue) -> Option { match time_value { - OwnedValue::Text(s) => get_date_time_from_time_value_string(&s.value), + OwnedValue::Text(s) => get_date_time_from_time_value_string(s.as_str()), OwnedValue::Integer(i) => get_date_time_from_time_value_integer(*i), OwnedValue::Float(f) => get_date_time_from_time_value_float(*f), _ => None, @@ -1100,7 +1100,7 @@ mod tests { for (input, expected) in test_cases { let result = exec_time(&[input]); if let OwnedValue::Text(result_str) = result { - assert_eq!(result_str.value.as_str(), expected); + assert_eq!(result_str.as_str(), expected); } else { panic!("Expected OwnedValue::Text, but got: {:?}", result); } diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 110601e82..b7b817e54 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -643,11 +643,11 @@ pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { | (OwnedValue::Integer(i), OwnedValue::Float(f)) => OwnedValue::Float(*f + *i as f64), (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_add( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_add(&cast_text_to_numerical(&text.value), other) + exec_add(&cast_text_to_numerical(text.as_str()), other) } _ => todo!(), } @@ -674,14 +674,14 @@ pub fn exec_subtract(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { (OwnedValue::Integer(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(*lhs as f64 - rhs), (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_subtract( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) => { - exec_subtract(&cast_text_to_numerical(&text.value), other) + exec_subtract(&cast_text_to_numerical(text.as_str()), other) } (other, OwnedValue::Text(text)) => { - exec_subtract(other, &cast_text_to_numerical(&text.value)) + exec_subtract(other, &cast_text_to_numerical(text.as_str())) } _ => todo!(), } @@ -707,11 +707,11 @@ pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { | (OwnedValue::Float(f), OwnedValue::Integer(i)) => OwnedValue::Float(*i as f64 * { *f }), (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_multiply( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_multiply(&cast_text_to_numerical(&text.value), other) + exec_multiply(&cast_text_to_numerical(text.as_str()), other) } _ => todo!(), @@ -740,11 +740,15 @@ pub fn exec_divide(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { (OwnedValue::Integer(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(*lhs as f64 / rhs), (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_divide( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), - (OwnedValue::Text(text), other) => exec_divide(&cast_text_to_numerical(&text.value), other), - (other, OwnedValue::Text(text)) => exec_divide(other, &cast_text_to_numerical(&text.value)), + (OwnedValue::Text(text), other) => { + exec_divide(&cast_text_to_numerical(text.as_str()), other) + } + (other, OwnedValue::Text(text)) => { + exec_divide(other, &cast_text_to_numerical(text.as_str())) + } _ => todo!(), } } @@ -769,11 +773,11 @@ pub fn exec_bit_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { (OwnedValue::Float(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(*lh as i64 & rh), (OwnedValue::Integer(lh), OwnedValue::Float(rh)) => OwnedValue::Integer(lh & *rh as i64), (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_bit_and( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_bit_and(&cast_text_to_numerical(&text.value), other) + exec_bit_and(&cast_text_to_numerical(text.as_str()), other) } _ => todo!(), } @@ -795,11 +799,11 @@ pub fn exec_bit_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { OwnedValue::Integer(*lh as i64 | *rh as i64) } (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_bit_or( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_bit_or(&cast_text_to_numerical(&text.value), other) + exec_bit_or(&cast_text_to_numerical(text.as_str()), other) } _ => todo!(), } @@ -839,7 +843,7 @@ pub fn exec_bit_not(mut reg: &OwnedValue) -> OwnedValue { OwnedValue::Null => OwnedValue::Null, OwnedValue::Integer(i) => OwnedValue::Integer(!i), OwnedValue::Float(f) => OwnedValue::Integer(!(*f as i64)), - OwnedValue::Text(text) => exec_bit_not(&cast_text_to_numerical(&text.value)), + OwnedValue::Text(text) => exec_bit_not(&cast_text_to_numerical(text.as_str())), _ => todo!(), } } @@ -866,14 +870,14 @@ pub fn exec_shift_left(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue OwnedValue::Integer(compute_shl(*lh as i64, *rh as i64)) } (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_shift_left( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) => { - exec_shift_left(&cast_text_to_numerical(&text.value), other) + exec_shift_left(&cast_text_to_numerical(text.as_str()), other) } (other, OwnedValue::Text(text)) => { - exec_shift_left(other, &cast_text_to_numerical(&text.value)) + exec_shift_left(other, &cast_text_to_numerical(text.as_str())) } _ => todo!(), } @@ -905,14 +909,14 @@ pub fn exec_shift_right(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValu OwnedValue::Integer(compute_shr(*lh as i64, *rh as i64)) } (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_shift_right( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) => { - exec_shift_right(&cast_text_to_numerical(&text.value), other) + exec_shift_right(&cast_text_to_numerical(text.as_str()), other) } (other, OwnedValue::Text(text)) => { - exec_shift_right(other, &cast_text_to_numerical(&text.value)) + exec_shift_right(other, &cast_text_to_numerical(text.as_str())) } _ => todo!(), } @@ -943,7 +947,7 @@ pub fn exec_boolean_not(mut reg: &OwnedValue) -> OwnedValue { OwnedValue::Null => OwnedValue::Null, OwnedValue::Integer(i) => OwnedValue::Integer((*i == 0) as i64), OwnedValue::Float(f) => OwnedValue::Integer((*f == 0.0) as i64), - OwnedValue::Text(text) => exec_boolean_not(&cast_text_to_numerical(&text.value)), + OwnedValue::Text(text) => exec_boolean_not(&cast_text_to_numerical(text.as_str())), _ => todo!(), } } @@ -951,20 +955,20 @@ pub fn exec_boolean_not(mut reg: &OwnedValue) -> OwnedValue { pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { match (lhs, rhs) { (OwnedValue::Text(lhs_text), OwnedValue::Text(rhs_text)) => { - OwnedValue::build_text(Rc::new(lhs_text.value.as_ref().clone() + &rhs_text.value)) + OwnedValue::build_text(Rc::new(lhs_text.as_str().to_string() + &rhs_text.as_str())) } (OwnedValue::Text(lhs_text), OwnedValue::Integer(rhs_int)) => OwnedValue::build_text( - Rc::new(lhs_text.value.as_ref().clone() + &rhs_int.to_string()), + Rc::new(lhs_text.as_str().to_string() + &rhs_int.to_string()), ), (OwnedValue::Text(lhs_text), OwnedValue::Float(rhs_float)) => OwnedValue::build_text( - Rc::new(lhs_text.value.as_ref().clone() + &rhs_float.to_string()), + Rc::new(lhs_text.as_str().to_string() + &rhs_float.to_string()), ), (OwnedValue::Text(lhs_text), OwnedValue::Agg(rhs_agg)) => OwnedValue::build_text(Rc::new( - lhs_text.value.as_ref().clone() + &rhs_agg.final_value().to_string(), + lhs_text.as_str().to_string() + &rhs_agg.final_value().to_string(), )), (OwnedValue::Integer(lhs_int), OwnedValue::Text(rhs_text)) => { - OwnedValue::build_text(Rc::new(lhs_int.to_string() + &rhs_text.value)) + OwnedValue::build_text(Rc::new(lhs_int.to_string() + rhs_text.as_str())) } (OwnedValue::Integer(lhs_int), OwnedValue::Integer(rhs_int)) => { OwnedValue::build_text(Rc::new(lhs_int.to_string() + &rhs_int.to_string())) @@ -977,7 +981,7 @@ pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { ), (OwnedValue::Float(lhs_float), OwnedValue::Text(rhs_text)) => { - OwnedValue::build_text(Rc::new(lhs_float.to_string() + &rhs_text.value)) + OwnedValue::build_text(Rc::new(lhs_float.to_string() + rhs_text.as_str())) } (OwnedValue::Float(lhs_float), OwnedValue::Integer(rhs_int)) => { OwnedValue::build_text(Rc::new(lhs_float.to_string() + &rhs_int.to_string())) @@ -989,9 +993,9 @@ pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { Rc::new(lhs_float.to_string() + &rhs_agg.final_value().to_string()), ), - (OwnedValue::Agg(lhs_agg), OwnedValue::Text(rhs_text)) => { - OwnedValue::build_text(Rc::new(lhs_agg.final_value().to_string() + &rhs_text.value)) - } + (OwnedValue::Agg(lhs_agg), OwnedValue::Text(rhs_text)) => OwnedValue::build_text(Rc::new( + lhs_agg.final_value().to_string() + rhs_text.as_str(), + )), (OwnedValue::Agg(lhs_agg), OwnedValue::Integer(rhs_int)) => OwnedValue::build_text( Rc::new(lhs_agg.final_value().to_string() + &rhs_int.to_string()), ), @@ -1025,11 +1029,11 @@ pub fn exec_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { | (OwnedValue::Float(0.0), _) => OwnedValue::Integer(0), (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_and( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_and(&cast_text_to_numerical(&text.value), other) + exec_and(&cast_text_to_numerical(text.as_str()), other) } _ => OwnedValue::Integer(1), } @@ -1054,11 +1058,11 @@ pub fn exec_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { | (OwnedValue::Float(0.0), OwnedValue::Float(0.0)) | (OwnedValue::Integer(0), OwnedValue::Integer(0)) => OwnedValue::Integer(0), (OwnedValue::Text(lhs), OwnedValue::Text(rhs)) => exec_or( - &cast_text_to_numerical(&lhs.value), - &cast_text_to_numerical(&rhs.value), + &cast_text_to_numerical(lhs.as_str()), + &cast_text_to_numerical(rhs.as_str()), ), (OwnedValue::Text(text), other) | (other, OwnedValue::Text(text)) => { - exec_or(&cast_text_to_numerical(&text.value), other) + exec_or(&cast_text_to_numerical(text.as_str()), other) } _ => OwnedValue::Integer(1), } diff --git a/core/vdbe/likeop.rs b/core/vdbe/likeop.rs index c2d75438d..3815c504a 100644 --- a/core/vdbe/likeop.rs +++ b/core/vdbe/likeop.rs @@ -7,7 +7,7 @@ use crate::{types::OwnedValue, LimboError}; pub fn construct_like_escape_arg(escape_value: &OwnedValue) -> Result { match escape_value { OwnedValue::Text(text) => { - let mut escape_chars = text.value.chars(); + let mut escape_chars = text.as_str().chars(); match (escape_chars.next(), escape_chars.next()) { (Some(escape), None) => Ok(escape), _ => Err(LimboError::Constraint( diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index b5f205e59..9132a9771 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -1840,12 +1840,12 @@ impl Program { // To the way blobs are parsed here in SQLite. let indent = match indent { Some(value) => match value { - OwnedValue::Text(text) => text.value.as_str(), + OwnedValue::Text(text) => text.as_str(), OwnedValue::Integer(val) => &val.to_string(), OwnedValue::Float(val) => &val.to_string(), OwnedValue::Blob(val) => &String::from_utf8_lossy(val), OwnedValue::Agg(ctx) => match ctx.final_value() { - OwnedValue::Text(text) => text.value.as_str(), + OwnedValue::Text(text) => text.as_str(), OwnedValue::Integer(val) => &val.to_string(), OwnedValue::Float(val) => &val.to_string(), OwnedValue::Blob(val) => &String::from_utf8_lossy(val), @@ -1883,7 +1883,8 @@ impl Program { else { unreachable!("Cast with non-text type"); }; - let result = exec_cast(®_value_argument, ®_value_type.value); + let result = + exec_cast(®_value_argument, ®_value_type.as_str()); state.registers[*dest] = result; } ScalarFunc::Changes => { @@ -1921,8 +1922,8 @@ impl Program { }; OwnedValue::Integer(exec_glob( cache, - &pattern.value, - &text.value, + &pattern.as_str(), + &text.as_str(), ) as i64) } @@ -1964,8 +1965,8 @@ impl Program { }; OwnedValue::Integer(exec_like_with_escape( - &pattern.value, - &text.value, + &pattern.as_str(), + &text.as_str(), escape, ) as i64) @@ -1978,8 +1979,8 @@ impl Program { }; OwnedValue::Integer(exec_like( cache, - &pattern.value, - &text.value, + &pattern.as_str(), + &text.as_str(), ) as i64) } @@ -2408,14 +2409,16 @@ impl Program { "MustBeInt: the value in register cannot be cast to integer" ), }, - OwnedValue::Text(text) => match checked_cast_text_to_numeric(&text.value) { - Ok(OwnedValue::Integer(i)) => { - state.registers[*reg] = OwnedValue::Integer(i) + OwnedValue::Text(text) => { + match checked_cast_text_to_numeric(&text.as_str()) { + Ok(OwnedValue::Integer(i)) => { + state.registers[*reg] = OwnedValue::Integer(i) + } + _ => crate::bail_parse_error!( + "MustBeInt: the value in register cannot be cast to integer" + ), } - _ => crate::bail_parse_error!( - "MustBeInt: the value in register cannot be cast to integer" - ), - }, + } _ => { crate::bail_parse_error!( "MustBeInt: the value in register cannot be cast to integer" @@ -2513,7 +2516,7 @@ impl Program { } Insn::CreateBtree { db, root, flags } => { if *db > 0 { - // TODO: implement temp datbases + // TODO: implement temp databases todo!("temp databases not implemented yet"); } let mut cursor = Box::new(BTreeCursor::new(pager.clone(), 0)); @@ -2717,7 +2720,7 @@ fn get_indent_count(indent_count: usize, curr_insn: &Insn, prev_insn: Option<&In fn exec_lower(reg: &OwnedValue) -> Option { match reg { - OwnedValue::Text(t) => Some(OwnedValue::build_text(Rc::new(t.value.to_lowercase()))), + OwnedValue::Text(t) => Some(OwnedValue::build_text(Rc::new(t.as_str().to_lowercase()))), t => Some(t.to_owned()), } } @@ -2746,7 +2749,7 @@ fn exec_octet_length(reg: &OwnedValue) -> OwnedValue { fn exec_upper(reg: &OwnedValue) -> Option { match reg { - OwnedValue::Text(t) => Some(OwnedValue::build_text(Rc::new(t.value.to_uppercase()))), + OwnedValue::Text(t) => Some(OwnedValue::build_text(Rc::new(t.as_str().to_uppercase()))), t => Some(t.to_owned()), } } @@ -2755,7 +2758,7 @@ fn exec_concat_strings(registers: &[OwnedValue]) -> OwnedValue { let mut result = String::new(); for reg in registers { match reg { - OwnedValue::Text(text) => result.push_str(&text.value), + OwnedValue::Text(text) => result.push_str(text.as_str()), OwnedValue::Integer(i) => result.push_str(&i.to_string()), OwnedValue::Float(f) => result.push_str(&f.to_string()), OwnedValue::Agg(aggctx) => result.push_str(&aggctx.final_value().to_string()), @@ -2773,9 +2776,9 @@ fn exec_concat_ws(registers: &[OwnedValue]) -> OwnedValue { } let separator = match ®isters[0] { - OwnedValue::Text(text) => text.value.clone(), - OwnedValue::Integer(i) => Rc::new(i.to_string()), - OwnedValue::Float(f) => Rc::new(f.to_string()), + OwnedValue::Text(text) => text.as_str().to_string(), + OwnedValue::Integer(i) => i.to_string(), + OwnedValue::Float(f) => f.to_string(), _ => return OwnedValue::Null, }; @@ -2785,7 +2788,7 @@ fn exec_concat_ws(registers: &[OwnedValue]) -> OwnedValue { result.push_str(&separator); } match reg { - OwnedValue::Text(text) => result.push_str(&text.value), + OwnedValue::Text(text) => result.push_str(text.as_str()), OwnedValue::Integer(i) => result.push_str(&i.to_string()), OwnedValue::Float(f) => result.push_str(&f.to_string()), _ => continue, @@ -2800,9 +2803,9 @@ fn exec_sign(reg: &OwnedValue) -> Option { OwnedValue::Integer(i) => *i as f64, OwnedValue::Float(f) => *f, OwnedValue::Text(s) => { - if let Ok(i) = s.value.parse::() { + if let Ok(i) = s.as_str().parse::() { i as f64 - } else if let Ok(f) = s.value.parse::() { + } else if let Ok(f) = s.as_str().parse::() { f } else { return Some(OwnedValue::Null); @@ -2840,7 +2843,7 @@ pub fn exec_soundex(reg: &OwnedValue) -> OwnedValue { OwnedValue::Null => return OwnedValue::build_text(Rc::new("?000".to_string())), OwnedValue::Text(s) => { // return ?000 if non ASCII alphabet character is found - if !s.value.chars().all(|c| c.is_ascii_alphabetic()) { + if !s.as_str().chars().all(|c| c.is_ascii_alphabetic()) { return OwnedValue::build_text(Rc::new("?000".to_string())); } s.clone() @@ -2850,7 +2853,7 @@ pub fn exec_soundex(reg: &OwnedValue) -> OwnedValue { // Remove numbers and spaces let word: String = s - .value + .as_str() .chars() .filter(|c| !c.is_ascii_digit()) .collect::() @@ -2958,7 +2961,7 @@ fn exec_randomblob(reg: &OwnedValue) -> OwnedValue { let length = match reg { OwnedValue::Integer(i) => *i, OwnedValue::Float(f) => *f as i64, - OwnedValue::Text(t) => t.value.parse().unwrap_or(1), + OwnedValue::Text(t) => t.as_str().parse().unwrap_or(1), _ => 1, } .max(1) as usize; @@ -2974,9 +2977,9 @@ fn exec_quote(value: &OwnedValue) -> OwnedValue { OwnedValue::Integer(_) | OwnedValue::Float(_) => value.to_owned(), OwnedValue::Blob(_) => todo!(), OwnedValue::Text(s) => { - let mut quoted = String::with_capacity(s.value.len() + 2); + let mut quoted = String::with_capacity(s.as_str().len() + 2); quoted.push('\''); - for c in s.value.chars() { + for c in s.as_str().chars() { if c == '\0' { break; } else { @@ -3081,7 +3084,7 @@ fn exec_substring( (str_value, start_value, length_value) { let start = *start as usize; - let str_len = str.value.len(); + let str_len = str.as_str().len(); if start > str_len { return OwnedValue::build_text(Rc::new("".to_string())); @@ -3093,19 +3096,19 @@ fn exec_substring( } else { str_len }; - let substring = &str.value[start_idx..end.min(str_len)]; + let substring = &str.as_str()[start_idx..end.min(str_len)]; OwnedValue::build_text(Rc::new(substring.to_string())) } else if let (OwnedValue::Text(str), OwnedValue::Integer(start)) = (str_value, start_value) { let start = *start as usize; - let str_len = str.value.len(); + let str_len = str.as_str().len(); if start > str_len { return OwnedValue::build_text(Rc::new("".to_string())); } let start_idx = start - 1; - let substring = &str.value[start_idx..str_len]; + let substring = &str.as_str()[start_idx..str_len]; OwnedValue::build_text(Rc::new(substring.to_string())) } else { @@ -3128,7 +3131,7 @@ fn exec_instr(reg: &OwnedValue, pattern: &OwnedValue) -> OwnedValue { let reg_str; let reg = match reg { - OwnedValue::Text(s) => s.value.as_str(), + OwnedValue::Text(s) => s.as_str(), _ => { reg_str = reg.to_string(); reg_str.as_str() @@ -3137,7 +3140,7 @@ fn exec_instr(reg: &OwnedValue, pattern: &OwnedValue) -> OwnedValue { let pattern_str; let pattern = match pattern { - OwnedValue::Text(s) => s.value.as_str(), + OwnedValue::Text(s) => s.as_str(), _ => { pattern_str = pattern.to_string(); pattern_str.as_str() @@ -3221,7 +3224,7 @@ fn exec_unicode(reg: &OwnedValue) -> OwnedValue { fn _to_float(reg: &OwnedValue) -> f64 { match reg { - OwnedValue::Text(x) => x.value.parse().unwrap_or(0.0), + OwnedValue::Text(x) => x.as_str().parse().unwrap_or(0.0), OwnedValue::Integer(x) => *x as f64, OwnedValue::Float(x) => *x, _ => 0.0, @@ -3230,7 +3233,7 @@ fn _to_float(reg: &OwnedValue) -> f64 { fn exec_round(reg: &OwnedValue, precision: Option) -> OwnedValue { let precision = match precision { - Some(OwnedValue::Text(x)) => x.value.parse().unwrap_or(0.0), + Some(OwnedValue::Text(x)) => x.as_str().parse().unwrap_or(0.0), Some(OwnedValue::Integer(x)) => x as f64, Some(OwnedValue::Float(x)) => x, Some(OwnedValue::Null) => return OwnedValue::Null, @@ -3259,7 +3262,9 @@ fn exec_trim(reg: &OwnedValue, pattern: Option) -> OwnedValue { } _ => reg.to_owned(), }, - (OwnedValue::Text(t), None) => OwnedValue::build_text(Rc::new(t.value.trim().to_string())), + (OwnedValue::Text(t), None) => { + OwnedValue::build_text(Rc::new(t.as_str().trim().to_string())) + } (reg, _) => reg.to_owned(), } } @@ -3279,7 +3284,7 @@ fn exec_ltrim(reg: &OwnedValue, pattern: Option) -> OwnedValue { _ => reg.to_owned(), }, (OwnedValue::Text(t), None) => { - OwnedValue::build_text(Rc::new(t.value.trim_start().to_string())) + OwnedValue::build_text(Rc::new(t.as_str().trim_start().to_string())) } (reg, _) => reg.to_owned(), } @@ -3300,7 +3305,7 @@ fn exec_rtrim(reg: &OwnedValue, pattern: Option) -> OwnedValue { _ => reg.to_owned(), }, (OwnedValue::Text(t), None) => { - OwnedValue::build_text(Rc::new(t.value.trim_end().to_string())) + OwnedValue::build_text(Rc::new(t.as_str().trim_end().to_string())) } (reg, _) => reg.to_owned(), } @@ -3310,7 +3315,7 @@ fn exec_zeroblob(req: &OwnedValue) -> OwnedValue { let length: i64 = match req { OwnedValue::Integer(i) => *i, OwnedValue::Float(f) => *f as i64, - OwnedValue::Text(s) => s.value.parse().unwrap_or(0), + OwnedValue::Text(s) => s.as_str().parse().unwrap_or(0), _ => 0, }; OwnedValue::Blob(Rc::new(vec![0; length.max(0) as usize])) @@ -3352,7 +3357,7 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue { let text = String::from_utf8_lossy(b); cast_text_to_real(&text) } - OwnedValue::Text(t) => cast_text_to_real(&t.value), + OwnedValue::Text(t) => cast_text_to_real(t.as_str()), OwnedValue::Integer(i) => OwnedValue::Float(*i as f64), OwnedValue::Float(f) => OwnedValue::Float(*f), _ => OwnedValue::Float(0.0), @@ -3363,7 +3368,7 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue { let text = String::from_utf8_lossy(b); cast_text_to_integer(&text) } - OwnedValue::Text(t) => cast_text_to_integer(&t.value), + OwnedValue::Text(t) => cast_text_to_integer(t.as_str()), OwnedValue::Integer(i) => OwnedValue::Integer(*i), // A cast of a REAL value into an INTEGER results in the integer between the REAL value and zero // that is closest to the REAL value. If a REAL is greater than the greatest possible signed integer (+9223372036854775807) @@ -3386,7 +3391,7 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue { let text = String::from_utf8_lossy(b); cast_text_to_numeric(&text) } - OwnedValue::Text(t) => cast_text_to_numeric(&t.value), + OwnedValue::Text(t) => cast_text_to_numeric(t.as_str()), OwnedValue::Integer(i) => OwnedValue::Integer(*i), OwnedValue::Float(f) => OwnedValue::Float(*f), _ => value.clone(), // TODO probably wrong @@ -3414,13 +3419,13 @@ fn exec_replace(source: &OwnedValue, pattern: &OwnedValue, replacement: &OwnedVa // If any of the casts failed, panic as text casting is not expected to fail. match (&source, &pattern, &replacement) { (OwnedValue::Text(source), OwnedValue::Text(pattern), OwnedValue::Text(replacement)) => { - if pattern.value.is_empty() { - return OwnedValue::build_text(source.value.clone()); + if pattern.as_str().is_empty() { + return OwnedValue::Text(source.clone()); } let result = source - .value - .replace(pattern.value.as_str(), &replacement.value); + .as_str() + .replace(pattern.as_str(), replacement.as_str()); OwnedValue::build_text(Rc::new(result)) } _ => unreachable!("text cast should never fail"), @@ -3567,7 +3572,7 @@ fn to_f64(reg: &OwnedValue) -> Option { match reg { OwnedValue::Integer(i) => Some(*i as f64), OwnedValue::Float(f) => Some(*f), - OwnedValue::Text(t) => t.value.parse::().ok(), + OwnedValue::Text(t) => t.as_str().parse::().ok(), OwnedValue::Agg(ctx) => to_f64(ctx.final_value()), _ => None, } @@ -3934,7 +3939,7 @@ mod tests { #[test] fn test_unhex() { - let input = OwnedValue::build_text(Rc::new(String::from("6F"))); + let input = OwnedValue::build_text(Rc::new(String::from("6f"))); let expected = OwnedValue::Blob(Rc::new(vec![0x6f])); assert_eq!(exec_unhex(&input, None), expected); diff --git a/core/vdbe/printf.rs b/core/vdbe/printf.rs index c4fb6a153..73e4bf4f3 100644 --- a/core/vdbe/printf.rs +++ b/core/vdbe/printf.rs @@ -9,7 +9,7 @@ pub fn exec_printf(values: &[OwnedValue]) -> crate::Result { return Ok(OwnedValue::Null); } let format_str = match &values[0] { - OwnedValue::Text(t) => &t.value, + OwnedValue::Text(t) => t.as_str(), _ => return Ok(OwnedValue::Null), }; @@ -44,7 +44,7 @@ pub fn exec_printf(values: &[OwnedValue]) -> crate::Result { return Err(LimboError::InvalidArgument("not enough arguments".into())); } match &values[args_index] { - OwnedValue::Text(t) => result.push_str(&t.value), + OwnedValue::Text(t) => result.push_str(t.as_str()), OwnedValue::Null => result.push_str("(null)"), v => result.push_str(&v.to_string()), } From f3902ef9b6473b01bc177594038075c889b0c9e6 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Thu, 6 Feb 2025 13:40:34 +0200 Subject: [PATCH 05/16] core: Rename OwnedRecord to Record We only have one record type so let's call it `Record`. --- core/lib.rs | 2 +- core/pseudo.rs | 8 ++++---- core/storage/btree.rs | 25 ++++++++++--------------- core/storage/sqlite3_ondisk.rs | 6 +++--- core/translate/group_by.rs | 4 ++-- core/translate/order_by.rs | 4 ++-- core/types.rs | 22 +++++++++++----------- core/vdbe/insn.rs | 4 ++-- core/vdbe/mod.rs | 16 ++++++++-------- core/vdbe/sorter.rs | 12 ++++++------ 10 files changed, 49 insertions(+), 54 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 98e37d470..618a5a678 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -501,7 +501,7 @@ impl Statement { } } -pub type Row = types::OwnedRecord; +pub type Row = types::Record; pub type StepResult = vdbe::StepResult; diff --git a/core/pseudo.rs b/core/pseudo.rs index 93d9d6a64..bcc6c91f0 100644 --- a/core/pseudo.rs +++ b/core/pseudo.rs @@ -1,7 +1,7 @@ -use crate::types::OwnedRecord; +use crate::types::Record; pub struct PseudoCursor { - current: Option, + current: Option, } impl PseudoCursor { @@ -9,11 +9,11 @@ impl PseudoCursor { Self { current: None } } - pub fn record(&self) -> Option<&OwnedRecord> { + pub fn record(&self) -> Option<&Record> { self.current.as_ref() } - pub fn insert(&mut self, record: OwnedRecord) { + pub fn insert(&mut self, record: Record) { self.current = Some(record); } } diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 5f53112e1..cb916a95e 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -5,7 +5,7 @@ use crate::storage::sqlite3_ondisk::{ read_btree_cell, read_varint, write_varint, BTreeCell, DatabaseHeader, PageContent, PageType, TableInteriorCell, TableLeafCell, }; -use crate::types::{CursorResult, OwnedRecord, OwnedValue, SeekKey, SeekOp}; +use crate::types::{CursorResult, OwnedValue, Record, SeekKey, SeekOp}; use crate::Result; use std::cell::{Ref, RefCell}; @@ -137,7 +137,7 @@ pub struct BTreeCursor { root_page: usize, /// Rowid and record are stored before being consumed. rowid: RefCell>, - record: RefCell>, + record: RefCell>, null_flag: bool, /// Index internal pages are consumed on the way up, so we store going upwards flag in case /// we just moved to a parent page and the parent page is an internal index page which requires @@ -198,7 +198,7 @@ impl BTreeCursor { /// Move the cursor to the previous record and return it. /// Used in backwards iteration. - fn get_prev_record(&mut self) -> Result, Option)>> { + fn get_prev_record(&mut self) -> Result, Option)>> { loop { let page = self.stack.top(); let cell_idx = self.stack.current_cell_index(); @@ -266,8 +266,7 @@ impl BTreeCursor { _rowid, _payload, .. }) => { self.stack.retreat(); - let record: OwnedRecord = - crate::storage::sqlite3_ondisk::read_record(&_payload)?; + let record: Record = crate::storage::sqlite3_ondisk::read_record(&_payload)?; return Ok(CursorResult::Ok((Some(_rowid), Some(record)))); } BTreeCell::IndexInteriorCell(_) => todo!(), @@ -281,7 +280,7 @@ impl BTreeCursor { fn get_next_record( &mut self, predicate: Option<(SeekKey<'_>, SeekOp)>, - ) -> Result, Option)>> { + ) -> Result, Option)>> { loop { let mem_page_rc = self.stack.top(); let cell_idx = self.stack.current_cell_index() as usize; @@ -444,7 +443,7 @@ impl BTreeCursor { &mut self, key: SeekKey<'_>, op: SeekOp, - ) -> Result, Option)>> { + ) -> Result, Option)>> { return_if_io!(self.move_to(key.clone(), op.clone())); { @@ -698,11 +697,7 @@ impl BTreeCursor { /// Insert a record into the btree. /// If the insert operation overflows the page, it will be split and the btree will be balanced. - fn insert_into_page( - &mut self, - key: &OwnedValue, - record: &OwnedRecord, - ) -> Result> { + fn insert_into_page(&mut self, key: &OwnedValue, record: &Record) -> Result> { if let CursorState::None = &self.state { self.state = CursorState::Write(WriteInfo::new()); } @@ -1581,7 +1576,7 @@ impl BTreeCursor { page_type: PageType, int_key: Option, cell_payload: &mut Vec, - record: &OwnedRecord, + record: &Record, ) { assert!(matches!( page_type, @@ -1815,14 +1810,14 @@ impl BTreeCursor { Ok(CursorResult::Ok(rowid.is_some())) } - pub fn record(&self) -> Result>> { + pub fn record(&self) -> Result>> { Ok(self.record.borrow()) } pub fn insert( &mut self, key: &OwnedValue, - _record: &OwnedRecord, + _record: &Record, moved_before: bool, /* Indicate whether it's necessary to traverse to find the leaf page */ ) -> Result> { let int_key = match key { diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 2f2482826..56bc959f9 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -46,7 +46,7 @@ use crate::io::{Buffer, Completion, ReadCompletion, SyncCompletion, WriteComplet use crate::storage::buffer_pool::BufferPool; use crate::storage::database::DatabaseStorage; use crate::storage::pager::Pager; -use crate::types::{OwnedRecord, OwnedValue, Text, TextSubtype}; +use crate::types::{OwnedValue, Record, Text, TextSubtype}; use crate::{File, Result}; use log::trace; use parking_lot::RwLock; @@ -948,7 +948,7 @@ impl TryFrom for SerialType { } } -pub fn read_record(payload: &[u8]) -> Result { +pub fn read_record(payload: &[u8]) -> Result { let mut pos = 0; let (header_size, nr) = read_varint(payload)?; assert!((header_size as usize) >= nr); @@ -969,7 +969,7 @@ pub fn read_record(payload: &[u8]) -> Result { pos += n; values.push(value); } - Ok(OwnedRecord::new(values)) + Ok(Record::new(values)) } pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(OwnedValue, usize)> { diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 5baa7f72a..0e521ccef 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -5,7 +5,7 @@ use sqlite3_parser::ast; use crate::{ function::AggFunc, schema::{Column, PseudoTable}, - types::{OwnedRecord, OwnedValue}, + types::{OwnedValue, Record}, vdbe::{ builder::{CursorType, ProgramBuilder}, insn::Insn, @@ -72,7 +72,7 @@ pub fn init_group_by( program.emit_insn(Insn::SorterOpen { cursor_id: sort_cursor, columns: aggregates.len() + group_by.exprs.len(), - order: OwnedRecord::new(order), + order: Record::new(order), }); program.add_comment(program.offset(), "clear group by abort flag"); diff --git a/core/translate/order_by.rs b/core/translate/order_by.rs index bfcb0d77c..091d6dd58 100644 --- a/core/translate/order_by.rs +++ b/core/translate/order_by.rs @@ -4,7 +4,7 @@ use sqlite3_parser::ast; use crate::{ schema::{Column, PseudoTable}, - types::{OwnedRecord, OwnedValue}, + types::{OwnedValue, Record}, util::exprs_are_equivalent, vdbe::{ builder::{CursorType, ProgramBuilder}, @@ -47,7 +47,7 @@ pub fn init_order_by( program.emit_insn(Insn::SorterOpen { cursor_id: sort_cursor, columns: order_by.len(), - order: OwnedRecord::new(order), + order: Record::new(order), }); Ok(()) } diff --git a/core/types.rs b/core/types.rs index cdc47e9f9..06c459741 100644 --- a/core/types.rs +++ b/core/types.rs @@ -85,7 +85,7 @@ pub enum OwnedValue { Text(Text), Blob(Rc>), Agg(Box), // TODO(pere): make this without Box. Currently this might cause cache miss but let's leave it for future analysis - Record(OwnedRecord), + Record(Record), } impl OwnedValue { @@ -520,11 +520,11 @@ impl<'a> FromValue<'a> for &'a str { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct OwnedRecord { +pub struct Record { pub values: Vec, } -impl OwnedRecord { +impl Record { pub fn get<'a, T: FromValue<'a> + 'a>(&'a self, idx: usize) -> Result { let value = &self.values[idx]; T::from_value(value) @@ -604,7 +604,7 @@ impl From for u64 { } } -impl OwnedRecord { +impl Record { pub fn new(values: Vec) -> Self { Self { values } } @@ -733,7 +733,7 @@ pub enum SeekOp { #[derive(Clone, PartialEq, Debug)] pub enum SeekKey<'a> { TableRowId(u64), - IndexKey(&'a OwnedRecord), + IndexKey(&'a Record), } #[cfg(test)] @@ -743,7 +743,7 @@ mod tests { #[test] fn test_serialize_null() { - let record = OwnedRecord::new(vec![OwnedValue::Null]); + let record = Record::new(vec![OwnedValue::Null]); let mut buf = Vec::new(); record.serialize(&mut buf); @@ -759,7 +759,7 @@ mod tests { #[test] fn test_serialize_integers() { - let record = OwnedRecord::new(vec![ + let record = Record::new(vec![ OwnedValue::Integer(42), // Should use SERIAL_TYPE_I8 OwnedValue::Integer(1000), // Should use SERIAL_TYPE_I16 OwnedValue::Integer(1_000_000), // Should use SERIAL_TYPE_I24 @@ -835,7 +835,7 @@ mod tests { #[test] fn test_serialize_float() { #[warn(clippy::approx_constant)] - let record = OwnedRecord::new(vec![OwnedValue::Float(3.15555)]); + let record = Record::new(vec![OwnedValue::Float(3.15555)]); let mut buf = Vec::new(); record.serialize(&mut buf); @@ -856,7 +856,7 @@ mod tests { #[test] fn test_serialize_text() { let text = Rc::new("hello".to_string()); - let record = OwnedRecord::new(vec![OwnedValue::Text(Text::new(text.clone()))]); + let record = Record::new(vec![OwnedValue::Text(Text::new(text.clone()))]); let mut buf = Vec::new(); record.serialize(&mut buf); @@ -875,7 +875,7 @@ mod tests { #[test] fn test_serialize_blob() { let blob = Rc::new(vec![1, 2, 3, 4, 5]); - let record = OwnedRecord::new(vec![OwnedValue::Blob(blob.clone())]); + let record = Record::new(vec![OwnedValue::Blob(blob.clone())]); let mut buf = Vec::new(); record.serialize(&mut buf); @@ -894,7 +894,7 @@ mod tests { #[test] fn test_serialize_mixed_types() { let text = Rc::new("test".to_string()); - let record = OwnedRecord::new(vec![ + let record = Record::new(vec![ OwnedValue::Null, OwnedValue::Integer(42), OwnedValue::Float(3.15), diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index b7b817e54..1cdb81e25 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use super::{AggFunc, BranchOffset, CursorID, FuncCtx, PageIdx}; use crate::storage::wal::CheckpointMode; -use crate::types::{OwnedRecord, OwnedValue}; +use crate::types::{OwnedValue, Record}; use limbo_macros::Description; /// Flags provided to comparison instructions (e.g. Eq, Ne) which determine behavior related to NULL values. @@ -418,7 +418,7 @@ pub enum Insn { SorterOpen { cursor_id: CursorID, // P1 columns: usize, // P2 - order: OwnedRecord, // P4. 0 if ASC and 1 if DESC + order: Record, // P4. 0 if ASC and 1 if DESC }, // Insert a row into the sorter. diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 9132a9771..2cb592828 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -37,7 +37,7 @@ use crate::storage::wal::CheckpointResult; use crate::storage::{btree::BTreeCursor, pager::Pager}; use crate::translate::plan::{ResultSetColumn, TableReference}; use crate::types::{ - AggContext, Cursor, CursorResult, ExternalAggState, OwnedRecord, OwnedValue, SeekKey, SeekOp, + AggContext, Cursor, CursorResult, ExternalAggState, OwnedValue, Record, SeekKey, SeekOp, }; use crate::util::parse_schema_rows; use crate::vdbe::builder::CursorType; @@ -298,7 +298,7 @@ pub struct ProgramState { pub pc: InsnReference, cursors: RefCell>>, registers: Vec, - pub(crate) result_row: Option, + pub(crate) result_row: Option, last_compare: Option, deferred_seek: Option<(CursorID, CursorID)>, ended_coroutine: Bitfield<4>, // flag to indicate that a coroutine has ended (key is the yield register. currently we assume that the yield register is always between 0-255, YOLO) @@ -1208,7 +1208,7 @@ impl Program { let mut cursors = state.cursors.borrow_mut(); if *is_index { let cursor = get_cursor_as_index_mut(&mut cursors, *cursor_id); - let record_from_regs: OwnedRecord = + let record_from_regs: Record = make_owned_record(&state.registers, start_reg, num_regs); let found = return_if_io!( cursor.seek(SeekKey::IndexKey(&record_from_regs), SeekOp::GE) @@ -1254,7 +1254,7 @@ impl Program { let mut cursors = state.cursors.borrow_mut(); if *is_index { let cursor = get_cursor_as_index_mut(&mut cursors, *cursor_id); - let record_from_regs: OwnedRecord = + let record_from_regs: Record = make_owned_record(&state.registers, start_reg, num_regs); let found = return_if_io!( cursor.seek(SeekKey::IndexKey(&record_from_regs), SeekOp::GT) @@ -1298,7 +1298,7 @@ impl Program { assert!(target_pc.is_offset()); let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_index_mut(&mut cursors, *cursor_id); - let record_from_regs: OwnedRecord = + let record_from_regs: Record = make_owned_record(&state.registers, start_reg, num_regs); if let Some(ref idx_record) = *cursor.record()? { // omit the rowid from the idx_record, which is the last value @@ -1322,7 +1322,7 @@ impl Program { assert!(target_pc.is_offset()); let mut cursors = state.cursors.borrow_mut(); let cursor = get_cursor_as_index_mut(&mut cursors, *cursor_id); - let record_from_regs: OwnedRecord = + let record_from_regs: Record = make_owned_record(&state.registers, start_reg, num_regs); if let Some(ref idx_record) = *cursor.record()? { // omit the rowid from the idx_record, which is the last value @@ -2655,12 +2655,12 @@ fn get_new_rowid(cursor: &mut BTreeCursor, mut rng: R) -> Result OwnedRecord { +fn make_owned_record(registers: &[OwnedValue], start_reg: &usize, count: &usize) -> Record { let mut values = Vec::with_capacity(*count); for r in registers.iter().skip(*start_reg).take(*count) { values.push(r.clone()) } - OwnedRecord::new(values) + Record::new(values) } fn trace_insn(program: &Program, addr: InsnReference, insn: &Insn) { diff --git a/core/vdbe/sorter.rs b/core/vdbe/sorter.rs index 2682c5f46..eecc42e04 100644 --- a/core/vdbe/sorter.rs +++ b/core/vdbe/sorter.rs @@ -1,9 +1,9 @@ -use crate::types::OwnedRecord; +use crate::types::Record; use std::cmp::Ordering; pub struct Sorter { - records: Vec, - current: Option, + records: Vec, + current: Option, order: Vec, } @@ -51,11 +51,11 @@ impl Sorter { pub fn next(&mut self) { self.current = self.records.pop(); } - pub fn record(&self) -> Option<&OwnedRecord> { + pub fn record(&self) -> Option<&Record> { self.current.as_ref() } - pub fn insert(&mut self, record: &OwnedRecord) { - self.records.push(OwnedRecord::new(record.values.to_vec())); + pub fn insert(&mut self, record: &Record) { + self.records.push(Record::new(record.values.to_vec())); } } From 7513f859df6dca5957e7acc8dac009a42ce70e53 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Thu, 6 Feb 2025 13:50:05 +0200 Subject: [PATCH 06/16] core: Move printf to functions module --- core/functions/mod.rs | 1 + core/{vdbe => functions}/printf.rs | 0 core/lib.rs | 1 + core/vdbe/mod.rs | 3 +-- 4 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 core/functions/mod.rs rename core/{vdbe => functions}/printf.rs (100%) diff --git a/core/functions/mod.rs b/core/functions/mod.rs new file mode 100644 index 000000000..1874bec84 --- /dev/null +++ b/core/functions/mod.rs @@ -0,0 +1 @@ +pub mod printf; diff --git a/core/vdbe/printf.rs b/core/functions/printf.rs similarity index 100% rename from core/vdbe/printf.rs rename to core/functions/printf.rs diff --git a/core/lib.rs b/core/lib.rs index 618a5a678..884aad963 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1,6 +1,7 @@ mod error; mod ext; mod function; +mod functions; mod info; mod io; #[cfg(feature = "json")] diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 2cb592828..8c84c98c3 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -22,13 +22,13 @@ mod datetime; pub mod explain; pub mod insn; pub mod likeop; -mod printf; pub mod sorter; mod strftime; use crate::error::{LimboError, SQLITE_CONSTRAINT_PRIMARYKEY}; use crate::ext::ExtValue; use crate::function::{AggFunc, ExtFunc, FuncCtx, MathFunc, MathFuncArity, ScalarFunc, VectorFunc}; +use crate::functions::printf::exec_printf; use crate::info; use crate::pseudo::PseudoCursor; use crate::result::LimboResult; @@ -60,7 +60,6 @@ use insn::{ exec_subtract, }; use likeop::{construct_like_escape_arg, exec_glob, exec_like_with_escape}; -use printf::exec_printf; use rand::distributions::{Distribution, Uniform}; use rand::{thread_rng, Rng}; use regex::{Regex, RegexBuilder}; From ee8eabf16739c29715970812449ff96357418abd Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Thu, 6 Feb 2025 13:52:25 +0200 Subject: [PATCH 07/16] core: Move datetime to functions module --- core/{vdbe => functions}/datetime.rs | 2 +- core/functions/mod.rs | 1 + core/vdbe/mod.rs | 9 ++++----- 3 files changed, 6 insertions(+), 6 deletions(-) rename core/{vdbe => functions}/datetime.rs (99%) diff --git a/core/vdbe/datetime.rs b/core/functions/datetime.rs similarity index 99% rename from core/vdbe/datetime.rs rename to core/functions/datetime.rs index 878e722e6..ce84816b1 100644 --- a/core/vdbe/datetime.rs +++ b/core/functions/datetime.rs @@ -122,7 +122,7 @@ fn format_dt(dt: NaiveDateTime, output_type: DateTimeOutput, subsec: bool) -> St // Not as fast as if the formatting was native to chrono, but a good enough // for now, just to have the feature implemented fn strftime_format(dt: &NaiveDateTime, format_str: &str) -> String { - use super::strftime::CustomStrftimeItems; + use crate::vdbe::strftime::CustomStrftimeItems; use std::fmt::Write; // Necessary to remove %f and %J that are exclusive formatters to sqlite // Chrono does not support them, so it is necessary to replace the modifiers manually diff --git a/core/functions/mod.rs b/core/functions/mod.rs index 1874bec84..d2a84e41e 100644 --- a/core/functions/mod.rs +++ b/core/functions/mod.rs @@ -1 +1,2 @@ +pub mod datetime; pub mod printf; diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 8c84c98c3..e22acb77f 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -18,16 +18,18 @@ //! https://www.sqlite.org/opcode.html pub mod builder; -mod datetime; pub mod explain; pub mod insn; pub mod likeop; pub mod sorter; -mod strftime; +pub mod strftime; use crate::error::{LimboError, SQLITE_CONSTRAINT_PRIMARYKEY}; use crate::ext::ExtValue; use crate::function::{AggFunc, ExtFunc, FuncCtx, MathFunc, MathFuncArity, ScalarFunc, VectorFunc}; +use crate::functions::datetime::{ + exec_date, exec_datetime_full, exec_julianday, exec_strftime, exec_time, exec_unixepoch, +}; use crate::functions::printf::exec_printf; use crate::info; use crate::pseudo::PseudoCursor; @@ -51,9 +53,6 @@ use crate::{ json::json_remove, json::json_set, json::json_type, }; use crate::{resolve_ext_path, Connection, Result, TransactionState, DATABASE_VERSION}; -use datetime::{ - exec_date, exec_datetime_full, exec_julianday, exec_strftime, exec_time, exec_unixepoch, -}; use insn::{ exec_add, exec_and, exec_bit_and, exec_bit_not, exec_bit_or, exec_boolean_not, exec_concat, exec_divide, exec_multiply, exec_or, exec_remainder, exec_shift_left, exec_shift_right, From f4a574e6bc92936c9a9b8f859895fe06e4c202e1 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Thu, 6 Feb 2025 13:53:36 +0200 Subject: [PATCH 08/16] core: Move strftime to functions module --- core/functions/datetime.rs | 2 +- core/functions/mod.rs | 1 + core/{vdbe => functions}/strftime.rs | 0 core/vdbe/mod.rs | 1 - 4 files changed, 2 insertions(+), 2 deletions(-) rename core/{vdbe => functions}/strftime.rs (100%) diff --git a/core/functions/datetime.rs b/core/functions/datetime.rs index ce84816b1..8c5c0aaff 100644 --- a/core/functions/datetime.rs +++ b/core/functions/datetime.rs @@ -122,7 +122,7 @@ fn format_dt(dt: NaiveDateTime, output_type: DateTimeOutput, subsec: bool) -> St // Not as fast as if the formatting was native to chrono, but a good enough // for now, just to have the feature implemented fn strftime_format(dt: &NaiveDateTime, format_str: &str) -> String { - use crate::vdbe::strftime::CustomStrftimeItems; + use crate::functions::strftime::CustomStrftimeItems; use std::fmt::Write; // Necessary to remove %f and %J that are exclusive formatters to sqlite // Chrono does not support them, so it is necessary to replace the modifiers manually diff --git a/core/functions/mod.rs b/core/functions/mod.rs index d2a84e41e..cf97f318a 100644 --- a/core/functions/mod.rs +++ b/core/functions/mod.rs @@ -1,2 +1,3 @@ pub mod datetime; pub mod printf; +pub mod strftime; diff --git a/core/vdbe/strftime.rs b/core/functions/strftime.rs similarity index 100% rename from core/vdbe/strftime.rs rename to core/functions/strftime.rs diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index e22acb77f..e77ccf551 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -22,7 +22,6 @@ pub mod explain; pub mod insn; pub mod likeop; pub mod sorter; -pub mod strftime; use crate::error::{LimboError, SQLITE_CONSTRAINT_PRIMARYKEY}; use crate::ext::ExtValue; From 55eb55a63480a56256c79f46438ea96e0d69ca00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=84=A0=EC=9A=B0?= Date: Thu, 6 Feb 2025 20:54:13 +0900 Subject: [PATCH 09/16] Rename package to tursodatabase --- bindings/java/example/build.gradle.kts | 8 +++++++- .../java/org/github/{seonwkim => tursodatabase}/Main.java | 8 ++++---- 2 files changed, 11 insertions(+), 5 deletions(-) rename bindings/java/example/src/main/java/org/github/{seonwkim => tursodatabase}/Main.java (77%) diff --git a/bindings/java/example/build.gradle.kts b/bindings/java/example/build.gradle.kts index 44e605d11..5a0c01a8c 100644 --- a/bindings/java/example/build.gradle.kts +++ b/bindings/java/example/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("java") } -group = "org.github.seonwkim" +group = "org.github.tursodatabase" version = "1.0-SNAPSHOT" repositories { @@ -19,3 +19,9 @@ dependencies { tasks.test { useJUnitPlatform() } + +tasks.register("run") { + group = "application" + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("org.github.tursodatabase.Main") +} diff --git a/bindings/java/example/src/main/java/org/github/seonwkim/Main.java b/bindings/java/example/src/main/java/org/github/tursodatabase/Main.java similarity index 77% rename from bindings/java/example/src/main/java/org/github/seonwkim/Main.java rename to bindings/java/example/src/main/java/org/github/tursodatabase/Main.java index ca1d8bc9d..41a477560 100644 --- a/bindings/java/example/src/main/java/org/github/seonwkim/Main.java +++ b/bindings/java/example/src/main/java/org/github/tursodatabase/Main.java @@ -1,4 +1,4 @@ -package org.github.seonwkim; +package org.github.tursodatabase; import java.sql.Connection; import java.sql.DriverManager; @@ -14,9 +14,9 @@ public class Main { ResultSet.CONCUR_READ_ONLY, ResultSet.CLOSE_CURSORS_AT_COMMIT); stmt.execute("CREATE TABLE users (id INT PRIMARY KEY, username TEXT);"); - stmt.execute("INSERT INTO users VALUES (1, 'seonwoo');"); - stmt.execute("INSERT INTO users VALUES (2, 'seonwoo');"); - stmt.execute("INSERT INTO users VALUES (3, 'seonwoo');"); + stmt.execute("INSERT INTO users VALUES (1, 'limbo');"); + stmt.execute("INSERT INTO users VALUES (2, 'turso');"); + stmt.execute("INSERT INTO users VALUES (3, 'who knows');"); stmt.execute("SELECT * FROM users"); System.out.println( "result: " + stmt.getResultSet().getInt(1) + ", " + stmt.getResultSet().getString(2)); From f5f77c0bd164e13e82c32f8bc2124c8bab429a8b Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Sun, 19 Jan 2025 15:10:47 +0200 Subject: [PATCH 10/16] Initial virtual table implementation --- Cargo.lock | 8 ++ Cargo.toml | 1 + core/ext/mod.rs | 80 +++++++++++++++++++- core/lib.rs | 74 ++++++++++++++++++- core/schema.rs | 17 +++++ core/translate/delete.rs | 3 +- core/translate/expr.rs | 55 ++++++++++---- core/translate/main_loop.rs | 138 +++++++++++++++++++++++------------ core/translate/optimizer.rs | 26 +++---- core/translate/plan.rs | 37 +++++++++- core/translate/planner.rs | 31 +++++++- core/types.rs | 9 +++ core/util.rs | 75 ++++++++++++++++++- core/vdbe/builder.rs | 6 +- core/vdbe/explain.rs | 57 +++++++++++++++ core/vdbe/insn.rs | 29 ++++++++ core/vdbe/mod.rs | 98 +++++++++++++++++++++++++ extensions/core/src/lib.rs | 75 ++++++++++++++++++- extensions/core/src/types.rs | 4 +- extensions/series/Cargo.toml | 15 ++++ extensions/series/src/lib.rs | 136 ++++++++++++++++++++++++++++++++++ macros/src/args.rs | 22 +++--- macros/src/lib.rs | 132 +++++++++++++++++++++++++++++++-- 23 files changed, 1015 insertions(+), 113 deletions(-) create mode 100644 extensions/series/Cargo.toml create mode 100644 extensions/series/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index f8a0df0e6..26cd80646 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1701,6 +1701,14 @@ dependencies = [ "regex", ] +[[package]] +name = "limbo_series" +version = "0.0.14" +dependencies = [ + "limbo_ext", + "log", +] + [[package]] name = "limbo_sim" version = "0.0.14" diff --git a/Cargo.toml b/Cargo.toml index 754583f42..400595b4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "extensions/percentile", "extensions/time", "extensions/crypto", + "extensions/series", ] exclude = ["perf/latency/limbo"] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index db7876431..6d034e313 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,6 +1,11 @@ -use crate::{function::ExternalFunc, Database}; -use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction}; +use crate::{function::ExternalFunc, util::columns_from_create_table_body, Database, VirtualTable}; +use fallible_iterator::FallibleIterator; +use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabModuleImpl}; pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType}; +use sqlite3_parser::{ + ast::{Cmd, Stmt}, + lexer::sql::Parser, +}; use std::{ ffi::{c_char, c_void, CStr}, rc::Rc, @@ -44,6 +49,48 @@ unsafe extern "C" fn register_aggregate_function( db.register_aggregate_function_impl(&name_str, args, (init_func, step_func, finalize_func)) } +unsafe extern "C" fn register_module( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + if ctx.is_null() { + return ResultCode::Error; + } + let db = unsafe { &mut *(ctx as *mut Database) }; + + db.register_module_impl(&name_str, module) +} + +unsafe extern "C" fn declare_vtab( + ctx: *mut c_void, + name: *const c_char, + sql: *const c_char, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + + let c_str = unsafe { CStr::from_ptr(sql) }; + let sql_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return ResultCode::Error, + }; + + if ctx.is_null() { + return ResultCode::Error; + } + let db = unsafe { &mut *(ctx as *mut Database) }; + db.declare_vtab_impl(&name_str, &sql_str) +} + impl Database { fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode { self.syms.borrow_mut().functions.insert( @@ -66,11 +113,40 @@ impl Database { ResultCode::OK } + fn register_module_impl(&mut self, name: &str, module: VTabModuleImpl) -> ResultCode { + self.vtab_modules.insert(name.to_string(), Rc::new(module)); + ResultCode::OK + } + + fn declare_vtab_impl(&mut self, name: &str, sql: &str) -> ResultCode { + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next().unwrap().unwrap(); + let Cmd::Stmt(stmt) = cmd else { + return ResultCode::Error; + }; + let Stmt::CreateTable { body, .. } = stmt else { + return ResultCode::Error; + }; + let Ok(columns) = columns_from_create_table_body(body) else { + return ResultCode::Error; + }; + let vtab_module = self.vtab_modules.get(name).unwrap().clone(); + let vtab = VirtualTable { + name: name.to_string(), + implementation: vtab_module, + columns, + }; + self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab); + ResultCode::OK + } + pub fn build_limbo_ext(&self) -> ExtensionApi { ExtensionApi { ctx: self as *const _ as *mut c_void, register_scalar_function, register_aggregate_function, + register_module, + declare_vtab, } } diff --git a/core/lib.rs b/core/lib.rs index 884aad963..ccc2c2273 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -25,12 +25,13 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; use fallible_iterator::FallibleIterator; #[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; -use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; +#[cfg(not(target_family = "wasm"))] +use limbo_ext::{ExtensionApi, ExtensionEntryPoint, ResultCode}; +use limbo_ext::{VTabModuleImpl, Value as ExtValue}; use log::trace; use parking_lot::RwLock; -use schema::Schema; -use sqlite3_parser::ast; -use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; +use schema::{Column, Schema}; +use sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser}; use std::cell::Cell; use std::collections::HashMap; use std::num::NonZero; @@ -44,9 +45,11 @@ use storage::pager::allocate_page; use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE}; pub use storage::wal::WalFile; pub use storage::wal::WalFileShared; +use types::OwnedValue; pub use types::Value; use util::parse_schema_rows; use vdbe::builder::QueryMode; +use vdbe::VTabOpaqueCursor; pub use error::LimboError; use translate::select::prepare_select_plan; @@ -82,6 +85,7 @@ pub struct Database { schema: Rc>, header: Rc>, syms: Rc>, + vtab_modules: HashMap>, // Shared structures of a Database are the parts that are common to multiple threads that might // create DB connections. _shared_page_cache: Arc>, @@ -144,6 +148,7 @@ impl Database { _shared_page_cache: _shared_page_cache.clone(), _shared_wal: shared_wal.clone(), syms, + vtab_modules: HashMap::new(), }; if let Err(e) = db.register_builtins() { return Err(LimboError::ExtensionError(e)); @@ -506,10 +511,70 @@ pub type Row = types::Record; pub type StepResult = vdbe::StepResult; +#[derive(Clone, Debug)] +pub struct VirtualTable { + name: String, + pub implementation: Rc, + columns: Vec, +} + +impl VirtualTable { + pub fn open(&self) -> VTabOpaqueCursor { + let cursor = unsafe { (self.implementation.open)() }; + VTabOpaqueCursor::new(cursor) + } + + pub fn filter( + &self, + cursor: &VTabOpaqueCursor, + arg_count: usize, + args: Vec, + ) -> Result<()> { + let mut filter_args = Vec::with_capacity(arg_count); + for i in 0..arg_count { + let ownedvalue_arg = args.get(i).unwrap(); + let extvalue_arg: ExtValue = match ownedvalue_arg { + OwnedValue::Null => Ok(ExtValue::null()), + OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)), + OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)), + OwnedValue::Text(t) => Ok(ExtValue::from_text((*t.value).clone())), + OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())), + other => Err(LimboError::ExtensionError(format!( + "Unsupported value type: {:?}", + other + ))), + }?; + filter_args.push(extvalue_arg); + } + let rc = unsafe { + (self.implementation.filter)(cursor.as_ptr(), arg_count as i32, filter_args.as_ptr()) + }; + match rc { + ResultCode::OK => Ok(()), + _ => Err(LimboError::ExtensionError("Filter failed".to_string())), + } + } + + pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result { + let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) }; + OwnedValue::from_ffi(&val) + } + + pub fn next(&self, cursor: &VTabOpaqueCursor) -> Result { + let rc = unsafe { (self.implementation.next)(cursor.as_ptr()) }; + match rc { + ResultCode::OK => Ok(true), + ResultCode::EOF => Ok(false), + _ => Err(LimboError::ExtensionError("Next failed".to_string())), + } + } +} + pub(crate) struct SymbolTable { pub functions: HashMap>, #[cfg(not(target_family = "wasm"))] extensions: Vec<(Library, *const ExtensionApi)>, + pub vtabs: HashMap, } impl std::fmt::Debug for SymbolTable { @@ -551,6 +616,7 @@ impl SymbolTable { pub fn new() -> Self { Self { functions: HashMap::new(), + vtabs: HashMap::new(), #[cfg(not(target_family = "wasm"))] extensions: Vec::new(), } diff --git a/core/schema.rs b/core/schema.rs index a5f1e6121..e7688b58e 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -1,3 +1,4 @@ +use crate::VirtualTable; use crate::{util::normalize_ident, Result}; use core::fmt; use fallible_iterator::FallibleIterator; @@ -47,6 +48,7 @@ impl Schema { pub enum Table { BTree(Rc), Pseudo(Rc), + Virtual(Rc), } impl Table { @@ -54,6 +56,7 @@ impl Table { match self { Table::BTree(table) => table.root_page, Table::Pseudo(_) => unimplemented!(), + Table::Virtual(_) => unimplemented!(), } } @@ -61,6 +64,7 @@ impl Table { match self { Self::BTree(table) => &table.name, Self::Pseudo(_) => "", + Self::Virtual(table) => &table.name, } } @@ -74,6 +78,10 @@ impl Table { .columns .get(index) .expect("column index out of bounds"), + Self::Virtual(table) => table + .columns + .get(index) + .expect("column index out of bounds"), } } @@ -81,6 +89,7 @@ impl Table { match self { Self::BTree(table) => &table.columns, Self::Pseudo(table) => &table.columns, + Self::Virtual(table) => &table.columns, } } @@ -88,6 +97,13 @@ impl Table { match self { Self::BTree(table) => Some(table.clone()), Self::Pseudo(_) => None, + Self::Virtual(_) => None, + } + } + pub fn virtual_table(&self) -> Option> { + match self { + Self::Virtual(table) => Some(table.clone()), + _ => None, } } } @@ -97,6 +113,7 @@ impl PartialEq for Table { match (self, other) { (Self::BTree(a), Self::BTree(b)) => Rc::ptr_eq(a, b), (Self::Pseudo(a), Self::Pseudo(b)) => Rc::ptr_eq(a, b), + (Self::Virtual(a), Self::Virtual(b)) => Rc::ptr_eq(a, b), _ => false, } } diff --git a/core/translate/delete.rs b/core/translate/delete.rs index ffad33d73..675b58f34 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -7,7 +7,7 @@ use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, QueryMode}; use crate::{schema::Schema, Result, SymbolTable}; use sqlite3_parser::ast::{Expr, Limit, QualifiedName}; -use super::plan::TableReference; +use super::plan::{TableReference, TableReferenceType}; pub fn translate_delete( query_mode: QueryMode, @@ -48,6 +48,7 @@ pub fn prepare_delete_plan( identifier: table.name.clone(), op: Operation::Scan { iter_dir: None }, join_info: None, + reference_type: TableReferenceType::BTreeTable, }]; let mut where_predicates = vec![]; diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 45fb1a648..8b2e70185 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -13,7 +13,7 @@ use crate::vdbe::{ use crate::Result; use super::emitter::Resolver; -use super::plan::{Operation, TableReference}; +use super::plan::{Operation, TableReference, TableReferenceType}; #[derive(Debug, Clone, Copy)] pub struct ConditionMetadata { @@ -1824,22 +1824,45 @@ pub fn translate_expr( // If we are reading a column from a table, we find the cursor that corresponds to // the table and read the column from the cursor. Operation::Scan { .. } | Operation::Search(_) => { - let cursor_id = program.resolve_cursor_id(&table_reference.identifier); - if *is_rowid_alias { - program.emit_insn(Insn::RowId { - cursor_id, - dest: target_register, - }); - } else { - program.emit_insn(Insn::Column { - cursor_id, - column: *column, - dest: target_register, - }); + match &table_reference.reference_type { + TableReferenceType::BTreeTable => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + if *is_rowid_alias { + program.emit_insn(Insn::RowId { + cursor_id, + dest: target_register, + }); + } else { + program.emit_insn(Insn::Column { + cursor_id, + column: *column, + dest: target_register, + }); + } + let column = table_reference.table.get_column_at(*column); + maybe_apply_affinity(column.ty, target_register, program); + Ok(target_register) + } + TableReferenceType::VirtualTable { .. } => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + program.emit_insn(Insn::VColumn { + cursor_id, + column: *column, + dest: target_register, + }); + Ok(target_register) + } + TableReferenceType::Subquery { + result_columns_start_reg, + } => { + program.emit_insn(Insn::Copy { + src_reg: result_columns_start_reg + *column, + dst_reg: target_register, + amount: 0, + }); + Ok(target_register) + } } - let column = table_reference.table.get_column_at(*column); - maybe_apply_affinity(column.ty, target_register, program); - Ok(target_register) } // If we are reading a column from a subquery, we instead copy the column from the // subquery's result registers. diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index b4fff9c7a..4558693e3 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -17,7 +17,7 @@ use super::{ order_by::{order_by_sorter_insert, sorter_insert}, plan::{ IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference, - WhereTerm, + TableReferenceType, WhereTerm, }, }; @@ -78,27 +78,40 @@ pub fn init_loop( } match &table.op { Operation::Scan { .. } => { + let ref_type = &table.reference_type; let cursor_id = program.alloc_cursor_id( Some(table.identifier.clone()), - CursorType::BTreeTable(table.btree().unwrap().clone()), + match ref_type { + TableReferenceType::BTreeTable => { + CursorType::BTreeTable(table.btree().unwrap().clone()) + } + TableReferenceType::VirtualTable { .. } => { + CursorType::VirtualTable(table.virtual_table().unwrap().clone()) + } + other => panic!("Invalid table reference type in Scan: {:?}", other), + }, ); - let root_page = table.table.get_root_page(); - - match mode { - OperationMode::SELECT => { + match (mode, ref_type) { + (OperationMode::SELECT, TableReferenceType::BTreeTable) => { + let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenReadAsync { cursor_id, root_page, }); program.emit_insn(Insn::OpenReadAwait {}); } - OperationMode::DELETE => { + (OperationMode::DELETE, TableReferenceType::BTreeTable) => { + let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenWriteAsync { cursor_id, root_page, }); program.emit_insn(Insn::OpenWriteAwait {}); } + (OperationMode::SELECT, TableReferenceType::VirtualTable { .. }) => { + program.emit_insn(Insn::VOpenAsync { cursor_id }); + program.emit_insn(Insn::VOpenAwait {}); + } _ => { unimplemented!() } @@ -245,31 +258,52 @@ pub fn open_loop( } } Operation::Scan { iter_dir } => { + let ref_type = &table.reference_type; let cursor_id = program.resolve_cursor_id(&table.identifier); - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::LastAsync { cursor_id }); - } else { - program.emit_insn(Insn::RewindAsync { cursor_id }); - } - program.emit_insn( + + if !matches!(ref_type, TableReferenceType::VirtualTable { .. }) { if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) { - Insn::LastAwait { - cursor_id, - pc_if_empty: loop_end, - } + program.emit_insn(Insn::LastAsync { cursor_id }); } else { - Insn::RewindAwait { - cursor_id, - pc_if_empty: loop_end, + program.emit_insn(Insn::RewindAsync { cursor_id }); + } + } + match ref_type { + TableReferenceType::BTreeTable => program.emit_insn( + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + Insn::LastAwait { + cursor_id, + pc_if_empty: loop_end, + } + } else { + Insn::RewindAwait { + cursor_id, + pc_if_empty: loop_end, + } + }, + ), + TableReferenceType::VirtualTable { args, .. } => { + let start_reg = program.alloc_registers(args.len()); + let mut cur_reg = start_reg; + for arg in args { + let reg = cur_reg; + cur_reg += 1; + translate_expr(program, Some(tables), arg, reg, &t_ctx.resolver)?; } - }, - ); + program.emit_insn(Insn::VFilter { + cursor_id, + arg_count: args.len(), + args_reg: start_reg, + }); + } + other => panic!("Unsupported table reference type: {:?}", other), + } program.resolve_label(loop_start, program.offset()); for cond in predicates @@ -688,29 +722,41 @@ pub fn close_loop( }); } Operation::Scan { iter_dir, .. } => { + let ref_type = &table.reference_type; program.resolve_label(loop_labels.next, program.offset()); let cursor_id = program.resolve_cursor_id(&table.identifier); - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::PrevAsync { cursor_id }); - } else { - program.emit_insn(Insn::NextAsync { cursor_id }); - } - if iter_dir - .as_ref() - .is_some_and(|dir| *dir == IterationDirection::Backwards) - { - program.emit_insn(Insn::PrevAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); - } else { - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: loop_labels.loop_start, - }); + match ref_type { + TableReferenceType::BTreeTable { .. } => { + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::PrevAsync { cursor_id }); + } else { + program.emit_insn(Insn::NextAsync { cursor_id }); + } + if iter_dir + .as_ref() + .is_some_and(|dir| *dir == IterationDirection::Backwards) + { + program.emit_insn(Insn::PrevAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } else { + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } + } + TableReferenceType::VirtualTable { .. } => { + program.emit_insn(Insn::VNext { + cursor_id, + pc_if_next: loop_labels.loop_start, + }); + } + other => unreachable!("Unsupported table reference type: {:?}", other), } } Operation::Search(search) => { diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 53df77956..73124060f 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -204,16 +204,16 @@ fn eliminate_constant_conditions( } fn push_scan_direction(table: &mut TableReference, direction: &Direction) { - match &mut table.op { - Operation::Scan { iter_dir, .. } => { - if iter_dir.is_none() { - match direction { - Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards), - Direction::Descending => *iter_dir = Some(IterationDirection::Backwards), - } + if let Operation::Scan { + ref mut iter_dir, .. + } = table.op + { + if iter_dir.is_none() { + match direction { + Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards), + Direction::Descending => *iter_dir = Some(IterationDirection::Backwards), } } - _ => {} } } @@ -309,12 +309,10 @@ impl Optimizable for ast::Expr { }; let column = table_reference.table.get_column_at(*column); for index in available_indexes_for_table.iter() { - if column - .name - .as_ref() - .map_or(false, |name| *name == index.columns.first().unwrap().name) - { - return Ok(Some(index.clone())); + if let Some(name) = column.name.as_ref() { + if &index.columns.first().unwrap().name == name { + return Ok(Some(index.clone())); + } } } Ok(None) diff --git a/core/translate/plan.rs b/core/translate/plan.rs index e59ffd5e8..43cba8e1b 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -9,6 +9,7 @@ use crate::{ function::AggFunc, schema::{BTreeTable, Column, Index, Table}, vdbe::BranchOffset, + VirtualTable, }; use crate::{ schema::{PseudoTable, Type}, @@ -197,11 +198,9 @@ pub struct TableReference { pub identifier: String, /// The join info for this table reference, if it is the right side of a join (which all except the first table reference have) pub join_info: Option, + pub reference_type: TableReferenceType, } -/** - A SourceOperator is a reference in the query plan that reads data from a table. -*/ #[derive(Clone, Debug)] pub enum Operation { // Scan operation @@ -226,10 +225,37 @@ pub enum Operation { }, } +/// The type of the table reference, either BTreeTable or Subquery +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TableReferenceType { + /// A BTreeTable is a table that is stored on disk in a B-tree index. + BTreeTable, + /// A subquery. + Subquery { + /// The index of the first register in the query plan that contains the result columns of the subquery. + result_columns_start_reg: usize, + }, + /// A virtual table. + VirtualTable { + /// Arguments to pass e.g. generate_series(1, 10, 2) + args: Vec, + }, +} + impl TableReference { /// Returns the btree table for this table reference, if it is a BTreeTable. pub fn btree(&self) -> Option> { - self.table.btree() + match &self.reference_type { + TableReferenceType::BTreeTable => self.table.btree(), + TableReferenceType::Subquery { .. } => None, + TableReferenceType::VirtualTable { .. } => None, + } + } + pub fn virtual_table(&self) -> Option> { + match &self.reference_type { + TableReferenceType::VirtualTable { .. } => self.table.virtual_table(), + _ => None, + } } /// Creates a new TableReference for a subquery. @@ -254,6 +280,9 @@ impl TableReference { result_columns_start_reg: 0, // Will be set in the bytecode emission phase }, table, + reference_type: TableReferenceType::Subquery { + result_columns_start_reg: 0, // Will be set in the bytecode emission phase + }, identifier: identifier.clone(), join_info, } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 95cee1edf..272b788a2 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,7 +1,9 @@ +use std::rc::Rc; + use super::{ plan::{ Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, - WhereTerm, + TableReferenceType, WhereTerm, }, select::prepare_select_plan, SymbolTable, @@ -301,6 +303,7 @@ fn parse_from_clause_table( table: Table::BTree(table.clone()), identifier: alias.unwrap_or(normalized_qualified_name), join_info: None, + reference_type: TableReferenceType::BTreeTable, }) } ast::SelectTable::Select(subselect, maybe_alias) => { @@ -317,8 +320,30 @@ fn parse_from_clause_table( ast::As::Elided(id) => id.0.clone(), }) .unwrap_or(format!("subquery_{}", cur_table_index)); - let table_reference = TableReference::new_subquery(identifier, subplan, None); - Ok(table_reference) + Ok(TableReference::new_subquery(identifier, subplan, None)) + } + ast::SelectTable::TableCall(qualified_name, mut maybe_args, maybe_alias) => { + let normalized_name = normalize_ident(qualified_name.name.0.as_str()); + let Some(vtab) = syms.vtabs.get(&normalized_name) else { + crate::bail_parse_error!("Virtual table {} not found", normalized_name); + }; + let alias = maybe_alias + .as_ref() + .map(|a| match a { + ast::As::As(id) => id.0.clone(), + ast::As::Elided(id) => id.0.clone(), + }) + .unwrap_or(normalized_name); + + Ok(TableReference { + op: Operation::Scan { iter_dir: None }, + join_info: None, + table: Table::Virtual(vtab.clone().into()), + identifier: alias.clone(), + reference_type: TableReferenceType::VirtualTable { + args: maybe_args.take().unwrap_or_default(), + }, + }) } _ => todo!(), } diff --git a/core/types.rs b/core/types.rs index 06c459741..ae31314c1 100644 --- a/core/types.rs +++ b/core/types.rs @@ -6,6 +6,7 @@ use crate::pseudo::PseudoCursor; use crate::storage::btree::BTreeCursor; use crate::storage::sqlite3_ondisk::write_varint; use crate::vdbe::sorter::Sorter; +use crate::vdbe::VTabOpaqueCursor; use crate::Result; use std::fmt::Display; use std::rc::Rc; @@ -670,6 +671,7 @@ pub enum Cursor { Index(BTreeCursor), Pseudo(PseudoCursor), Sorter(Sorter), + Virtual(VTabOpaqueCursor), } impl Cursor { @@ -716,6 +718,13 @@ impl Cursor { _ => panic!("Cursor is not a sorter cursor"), } } + + pub fn as_virtual_mut(&mut self) -> &mut VTabOpaqueCursor { + match self { + Self::Virtual(cursor) => cursor, + _ => panic!("Cursor is not a virtual cursor"), + } + } } pub enum CursorResult { diff --git a/core/util.rs b/core/util.rs index c92c31412..5251b36cf 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1,9 +1,9 @@ use std::{rc::Rc, sync::Arc}; -use sqlite3_parser::ast::{Expr, FunctionTail, Literal}; +use sqlite3_parser::ast::{CreateTableBody, Expr, FunctionTail, Literal}; use crate::{ - schema::{self, Schema}, + schema::{self, Column, Schema, Type}, Result, Statement, StepResult, IO, }; @@ -308,6 +308,77 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } +pub fn columns_from_create_table_body(body: CreateTableBody) -> Result, ()> { + let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else { + return Err(()); + }; + + Ok(columns + .into_iter() + .filter_map(|(name, column_def)| { + // if column_def.col_type includes HIDDEN, omit it for now + if let Some(data_type) = column_def.col_type.as_ref() { + if data_type.name.as_str().contains("HIDDEN") { + return None; + } + } + let column = Column { + name: Some(name.0), + ty: match column_def.col_type { + Some(ref data_type) => { + // https://www.sqlite.org/datatype3.html + let type_name = data_type.name.as_str().to_uppercase(); + if type_name.contains("INT") { + Type::Integer + } else if type_name.contains("CHAR") + || type_name.contains("CLOB") + || type_name.contains("TEXT") + { + Type::Text + } else if type_name.contains("BLOB") || type_name.is_empty() { + Type::Blob + } else if type_name.contains("REAL") + || type_name.contains("FLOA") + || type_name.contains("DOUB") + { + Type::Real + } else { + Type::Numeric + } + } + None => Type::Null, + }, + default: column_def + .constraints + .iter() + .find_map(|c| match &c.constraint { + sqlite3_parser::ast::ColumnConstraint::Default(val) => Some(val.clone()), + _ => None, + }), + notnull: column_def.constraints.iter().any(|c| { + matches!( + c.constraint, + sqlite3_parser::ast::ColumnConstraint::NotNull { .. } + ) + }), + ty_str: column_def + .col_type + .clone() + .map(|t| t.name.to_string()) + .unwrap_or_default(), + primary_key: column_def.constraints.iter().any(|c| { + matches!( + c.constraint, + sqlite3_parser::ast::ColumnConstraint::PrimaryKey { .. } + ) + }), + is_rowid_alias: false, + }; + Some(column) + }) + .collect::>()) +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 21b4b8949..c3ead0d38 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -9,7 +9,7 @@ use crate::{ schema::{BTreeTable, Index, PseudoTable}, storage::sqlite3_ondisk::DatabaseHeader, translate::plan::{ResultSetColumn, TableReference}, - Connection, + Connection, VirtualTable, }; use super::{BranchOffset, CursorID, Insn, InsnReference, Program}; @@ -40,6 +40,7 @@ pub enum CursorType { BTreeIndex(Rc), Pseudo(Rc), Sorter, + VirtualTable(Rc), } impl CursorType { @@ -406,6 +407,9 @@ impl ProgramBuilder { Insn::IsNull { reg: _, target_pc } => { resolve(target_pc, "IsNull"); } + Insn::VNext { pc_if_next, .. } => { + resolve(pc_if_next, "VNext"); + } _ => continue, } } diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index e4c302bba..a0bb63023 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -363,6 +363,62 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::VOpenAsync { cursor_id } => ( + "VOpenAsync", + *cursor_id as i32, + 0, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VOpenAwait => ( + "VOpenAwait", + 0, + 0, + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VFilter { + cursor_id, + arg_count, + args_reg, + } => ( + "VFilter", + *cursor_id as i32, + *arg_count as i32, + *args_reg as i32, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VColumn { + cursor_id, + column, + dest, + } => ( + "VColumn", + *cursor_id as i32, + *column as i32, + *dest as i32, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::VNext { + cursor_id, + pc_if_next, + } => ( + "VNext", + *cursor_id as i32, + pc_if_next.to_debug_int(), + 0, + OwnedValue::build_text(Rc::new("".to_string())), + 0, + "".to_string(), + ), Insn::OpenPseudo { cursor_id, content_reg, @@ -423,6 +479,7 @@ pub fn insn_to_str( name } CursorType::Sorter => None, + CursorType::VirtualTable(v) => v.columns.get(*column).unwrap().name.as_ref(), }; ( "Column", diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 1cdb81e25..223f321aa 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -213,6 +213,35 @@ pub enum Insn { // Await for the completion of open cursor. OpenReadAwait, + /// Open a cursor for a virtual table. + VOpenAsync { + cursor_id: CursorID, + }, + + /// Await for the completion of open cursor for a virtual table. + VOpenAwait, + + /// Initialize the position of the virtual table cursor. + VFilter { + cursor_id: CursorID, + arg_count: usize, + args_reg: usize, + }, + + /// Read a column from the current row of the virtual table cursor. + VColumn { + cursor_id: CursorID, + column: usize, + dest: usize, + }, + + /// Advance the virtual table cursor to the next row. + /// TODO: async + VNext { + cursor_id: CursorID, + pc_if_next: BranchOffset, + }, + // Open a cursor for a pseudo-table that contains a single row. OpenPseudo { cursor_id: CursorID, diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index e77ccf551..c25dc572b 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -65,6 +65,7 @@ use sorter::Sorter; use std::borrow::BorrowMut; use std::cell::{Cell, RefCell, RefMut}; use std::collections::HashMap; +use std::ffi::c_void; use std::num::NonZero; use std::rc::{Rc, Weak}; @@ -267,6 +268,18 @@ fn get_cursor_as_sorter_mut<'long, 'short>( cursor } +fn get_cursor_as_virtual_mut<'long, 'short>( + cursors: &'short mut RefMut<'long, Vec>>, + cursor_id: CursorID, +) -> &'short mut VTabOpaqueCursor { + let cursor = cursors + .get_mut(cursor_id) + .expect("cursor id out of bounds") + .as_mut() + .expect("cursor not allocated") + .as_virtual_mut(); + cursor +} struct Bitfield([u64; N]); impl Bitfield { @@ -290,6 +303,18 @@ impl Bitfield { } } +pub struct VTabOpaqueCursor(*mut c_void); + +impl VTabOpaqueCursor { + pub fn new(cursor: *mut c_void) -> Self { + Self(cursor) + } + + pub fn as_ptr(&self) -> *mut c_void { + self.0 + } +} + /// The program state describes the environment in which the program executes. pub struct ProgramState { pub pc: InsnReference, @@ -370,6 +395,7 @@ macro_rules! must_be_btree_cursor { CursorType::BTreeIndex(_) => get_cursor_as_index_mut(&mut $cursors, $cursor_id), CursorType::Pseudo(_) => panic!("{} on pseudo cursor", $insn_name), CursorType::Sorter => panic!("{} on sorter cursor", $insn_name), + CursorType::VirtualTable(_) => panic!("{} on virtual table cursor", $insn_name), }; cursor }}; @@ -826,12 +852,79 @@ impl Program { CursorType::Sorter => { panic!("OpenReadAsync on sorter cursor"); } + CursorType::VirtualTable(_) => { + panic!("OpenReadAsync on virtual table cursor, use Insn::VOpenAsync instead"); + } } state.pc += 1; } Insn::OpenReadAwait => { state.pc += 1; } + Insn::VOpenAsync { cursor_id } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VOpenAsync on non-virtual table cursor"); + }; + let cursor = virtual_table.open(); + state + .cursors + .borrow_mut() + .insert(*cursor_id, Some(Cursor::Virtual(cursor))); + state.pc += 1; + } + Insn::VOpenAwait => { + state.pc += 1; + } + Insn::VFilter { + cursor_id, + arg_count, + args_reg, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VFilter on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + let mut args = Vec::new(); + for i in 0..*arg_count { + args.push(state.registers[args_reg + i].clone()); + } + virtual_table.filter(cursor, *arg_count, args)?; + state.pc += 1; + } + Insn::VColumn { + cursor_id, + column, + dest, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VColumn on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + state.registers[*dest] = virtual_table.column(cursor, *column)?; + state.pc += 1; + } + Insn::VNext { + cursor_id, + pc_if_next, + } => { + let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap(); + let CursorType::VirtualTable(virtual_table) = cursor_type else { + panic!("VNextAsync on non-virtual table cursor"); + }; + let mut cursors = state.cursors.borrow_mut(); + let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id); + let has_more = virtual_table.next(cursor)?; + if has_more { + state.pc = pc_if_next.to_offset_int(); + } else { + state.pc += 1; + } + } Insn::OpenPseudo { cursor_id, content_reg: _, @@ -943,6 +1036,11 @@ impl Program { state.registers[*dest] = OwnedValue::Null; } } + CursorType::VirtualTable(_) => { + panic!( + "Insn::Column on virtual table cursor, use Insn::VColumn instead" + ); + } } state.pc += 1; diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 5f9bb09c5..fec363c44 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -1,5 +1,5 @@ mod types; -pub use limbo_macros::{register_extension, scalar, AggregateDerive}; +pub use limbo_macros::{register_extension, scalar, AggregateDerive, VTabModuleDerive}; use std::os::raw::{c_char, c_void}; pub use types::{ResultCode, Value, ValueType}; @@ -21,6 +21,30 @@ pub struct ExtensionApi { step_func: StepFunction, finalize_func: FinalizeFunction, ) -> ResultCode, + + pub register_module: unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + module: VTabModuleImpl, + ) -> ResultCode, + + pub declare_vtab: unsafe extern "C" fn( + ctx: *mut c_void, + name: *const c_char, + sql: *const c_char, + ) -> ResultCode, +} + +impl ExtensionApi { + pub fn declare_virtual_table(&self, name: &str, sql: &str) -> ResultCode { + let Ok(name) = std::ffi::CString::new(name) else { + return ResultCode::Error; + }; + let Ok(sql) = std::ffi::CString::new(sql) else { + return ResultCode::Error; + }; + unsafe { (self.declare_vtab)(self.ctx, name.as_ptr(), sql.as_ptr()) } + } } pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode; @@ -47,3 +71,52 @@ pub trait AggFunc { fn step(state: &mut Self::State, args: &[Value]); fn finalize(state: Self::State) -> Value; } + +#[repr(C)] +#[derive(Clone, Debug)] +pub struct VTabModuleImpl { + pub name: *const c_char, + pub connect: VtabFnConnect, + pub open: VtabFnOpen, + pub filter: VtabFnFilter, + pub column: VtabFnColumn, + pub next: VtabFnNext, + pub eof: VtabFnEof, +} + +pub type VtabFnConnect = unsafe extern "C" fn(api: *const c_void) -> ResultCode; + +pub type VtabFnOpen = unsafe extern "C" fn() -> *mut c_void; + +pub type VtabFnFilter = + unsafe extern "C" fn(cursor: *mut c_void, argc: i32, argv: *const Value) -> ResultCode; + +pub type VtabFnColumn = unsafe extern "C" fn(cursor: *mut c_void, idx: u32) -> Value; + +pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode; + +pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; + +pub trait VTabModule: 'static { + type VCursor: VTabCursor; + + fn name() -> &'static str; + fn connect(api: &ExtensionApi) -> ResultCode; + fn open() -> Self::VCursor; + fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; + fn column(cursor: &Self::VCursor, idx: u32) -> Value; + fn next(cursor: &mut Self::VCursor) -> ResultCode; + fn eof(cursor: &Self::VCursor) -> bool; +} + +pub trait VTabCursor: Sized { + fn rowid(&self) -> i64; + fn column(&self, idx: u32) -> Value; + fn eof(&self) -> bool; + fn next(&mut self) -> ResultCode; +} + +#[repr(C)] +pub struct VTabImpl { + pub module: VTabModuleImpl, +} diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 464e07bfd..4a1fa3978 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -2,8 +2,8 @@ use std::fmt::Display; /// Error type is of type ExtError which can be /// either a user defined error or an error code -#[derive(Clone, Copy)] #[repr(C)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum ResultCode { OK = 0, Error = 1, @@ -20,6 +20,7 @@ pub enum ResultCode { Internal = 12, Unavailable = 13, CustomError = 14, + EOF = 15, } impl ResultCode { @@ -50,6 +51,7 @@ impl Display for ResultCode { ResultCode::Internal => write!(f, "Internal Error"), ResultCode::Unavailable => write!(f, "Unavailable"), ResultCode::CustomError => write!(f, "Error "), + ResultCode::EOF => write!(f, "EOF"), } } } diff --git a/extensions/series/Cargo.toml b/extensions/series/Cargo.toml new file mode 100644 index 000000000..73a634ac7 --- /dev/null +++ b/extensions/series/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "limbo_series" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib", "lib"] + + +[dependencies] +limbo_ext = { path = "../core"} +log = "0.4.20" diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs new file mode 100644 index 000000000..f438c6fce --- /dev/null +++ b/extensions/series/src/lib.rs @@ -0,0 +1,136 @@ +use limbo_ext::{ + register_extension, ExtensionApi, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value, + ValueType, +}; + +register_extension! { + vtabs: { GenerateSeriesVTab } +} + +/// A virtual table that generates a sequence of integers +#[derive(Debug, VTabModuleDerive)] +struct GenerateSeriesVTab; + +impl VTabModule for GenerateSeriesVTab { + type VCursor = GenerateSeriesCursor; + fn name() -> &'static str { + "generate_series" + } + + fn connect(api: &ExtensionApi) -> ResultCode { + // Create table schema + let sql = "CREATE TABLE generate_series( + value INTEGER, + start INTEGER HIDDEN, + stop INTEGER HIDDEN, + step INTEGER HIDDEN + )"; + let name = Self::name(); + api.declare_virtual_table(name, sql) + } + + fn open() -> Self::VCursor { + GenerateSeriesCursor { + start: 0, + stop: 0, + step: 0, + current: 0, + } + } + + fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { + // args are the start, stop, and step + if arg_count == 0 || arg_count > 3 { + return ResultCode::InvalidArgs; + } + let start = { + if args[0].value_type() == ValueType::Integer { + args[0].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + let stop = if args.len() == 1 { + i64::MAX + } else { + if args[1].value_type() == ValueType::Integer { + args[1].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + let step = if args.len() <= 2 { + 1 + } else { + if args[2].value_type() == ValueType::Integer { + args[2].to_integer().unwrap() + } else { + return ResultCode::InvalidArgs; + } + }; + cursor.start = start; + cursor.current = start; + cursor.stop = stop; + cursor.step = step; + ResultCode::OK + } + + fn column(cursor: &Self::VCursor, idx: u32) -> Value { + cursor.column(idx) + } + + fn next(cursor: &mut Self::VCursor) -> ResultCode { + GenerateSeriesCursor::next(cursor) + } + + fn eof(cursor: &Self::VCursor) -> bool { + cursor.eof() + } +} + +/// The cursor for iterating over the generated sequence +#[derive(Debug)] +struct GenerateSeriesCursor { + start: i64, + stop: i64, + step: i64, + current: i64, +} + +impl GenerateSeriesCursor { + fn next(&mut self) -> ResultCode { + let current = self.current; + + // Check if we've reached the end + if (self.step > 0 && current >= self.stop) || (self.step < 0 && current <= self.stop) { + return ResultCode::EOF; + } + + self.current = current.saturating_add(self.step); + ResultCode::OK + } +} + +impl VTabCursor for GenerateSeriesCursor { + fn next(&mut self) -> ResultCode { + GenerateSeriesCursor::next(self) + } + + fn eof(&self) -> bool { + (self.step > 0 && self.current > self.stop) || (self.step < 0 && self.current < self.stop) + } + + fn column(&self, idx: u32) -> Value { + match idx { + 0 => Value::from_integer(self.current), + 1 => Value::from_integer(self.start), + 2 => Value::from_integer(self.stop), + 3 => Value::from_integer(self.step), + _ => Value::null(), + } + } + + fn rowid(&self) -> i64 { + ((self.current - self.start) / self.step) + 1 + } +} diff --git a/macros/src/args.rs b/macros/src/args.rs index d9e59cbd3..12446b660 100644 --- a/macros/src/args.rs +++ b/macros/src/args.rs @@ -6,31 +6,32 @@ use syn::{Ident, LitStr, Token}; pub(crate) struct RegisterExtensionInput { pub aggregates: Vec, pub scalars: Vec, + pub vtabs: Vec, } impl syn::parse::Parse for RegisterExtensionInput { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut aggregates = Vec::new(); let mut scalars = Vec::new(); - + let mut vtabs = Vec::new(); while !input.is_empty() { if input.peek(syn::Ident) && input.peek2(Token![:]) { let section_name: Ident = input.parse()?; input.parse::()?; - - if section_name == "aggregates" || section_name == "scalars" { + let names = ["aggregates", "scalars", "vtabs"]; + if names.contains(§ion_name.to_string().as_str()) { let content; syn::braced!(content in input); - let parsed_items = Punctuated::::parse_terminated(&content)? .into_iter() .collect(); - if section_name == "aggregates" { - aggregates = parsed_items; - } else { - scalars = parsed_items; - } + match section_name.to_string().as_str() { + "aggregates" => aggregates = parsed_items, + "scalars" => scalars = parsed_items, + "vtabs" => vtabs = parsed_items, + _ => unreachable!(), + }; if input.peek(Token![,]) { input.parse::()?; @@ -39,13 +40,14 @@ impl syn::parse::Parse for RegisterExtensionInput { return Err(syn::Error::new(section_name.span(), "Unknown section")); } } else { - return Err(input.error("Expected aggregates: or scalars: section")); + return Err(input.error("Expected aggregates:, scalars:, or vtabs: section")); } } Ok(Self { aggregates, scalars, + vtabs, }) } } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 1e0ef421e..6b0df9679 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -324,6 +324,103 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +#[proc_macro_derive(VTabModuleDerive)] +pub fn derive_vtab_module(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let struct_name = &ast.ident; + + let register_fn_name = format_ident!("register_{}", struct_name); + let connect_fn_name = format_ident!("connect_{}", struct_name); + let open_fn_name = format_ident!("open_{}", struct_name); + let filter_fn_name = format_ident!("filter_{}", struct_name); + let column_fn_name = format_ident!("column_{}", struct_name); + let next_fn_name = format_ident!("next_{}", struct_name); + let eof_fn_name = format_ident!("eof_{}", struct_name); + + let expanded = quote! { + impl #struct_name { + #[no_mangle] + unsafe extern "C" fn #connect_fn_name( + db: *const ::std::ffi::c_void, + ) -> ::limbo_ext::ResultCode { + let api = unsafe { &*(db as *const ExtensionApi) }; + <#struct_name as ::limbo_ext::VTabModule>::connect(api) + } + + #[no_mangle] + unsafe extern "C" fn #open_fn_name( + ) -> *mut ::std::ffi::c_void { + let cursor = <#struct_name as ::limbo_ext::VTabModule>::open(); + Box::into_raw(Box::new(cursor)) as *mut ::std::ffi::c_void + } + + #[no_mangle] + unsafe extern "C" fn #filter_fn_name( + cursor: *mut ::std::ffi::c_void, + argc: i32, + argv: *const ::limbo_ext::Value, + ) -> ::limbo_ext::ResultCode { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + let args = std::slice::from_raw_parts(argv, argc as usize); + <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args) + } + + #[no_mangle] + unsafe extern "C" fn #column_fn_name( + cursor: *mut ::std::ffi::c_void, + idx: u32, + ) -> ::limbo_ext::Value { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) + } + + #[no_mangle] + unsafe extern "C" fn #next_fn_name( + cursor: *mut ::std::ffi::c_void, + ) -> ::limbo_ext::ResultCode { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::next(cursor) + } + + #[no_mangle] + unsafe extern "C" fn #eof_fn_name( + cursor: *mut ::std::ffi::c_void, + ) -> bool { + let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) + } + + #[no_mangle] + pub unsafe extern "C" fn #register_fn_name( + api: *const ::limbo_ext::ExtensionApi + ) -> ::limbo_ext::ResultCode { + if api.is_null() { + return ::limbo_ext::ResultCode::Error; + } + + let api = &*api; + let name = <#struct_name as ::limbo_ext::VTabModule>::name(); + // name needs to be a c str FFI compatible, NOT CString + let name_c = std::ffi::CString::new(name).unwrap(); + + let module = ::limbo_ext::VTabModuleImpl { + name: name_c.as_ptr(), + connect: Self::#connect_fn_name, + open: Self::#open_fn_name, + filter: Self::#filter_fn_name, + column: Self::#column_fn_name, + next: Self::#next_fn_name, + eof: Self::#eof_fn_name, + }; + + (api.register_module)(api.ctx, name_c.as_ptr(), module) + } + } + }; + + TokenStream::from(expanded) +} + /// Register your extension with 'core' by providing the relevant functions ///```ignore ///use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc}; @@ -362,6 +459,7 @@ pub fn register_extension(input: TokenStream) -> TokenStream { let RegisterExtensionInput { aggregates, scalars, + vtabs, } = input_ast; let scalar_calls = scalars.iter().map(|scalar_ident| { @@ -388,8 +486,23 @@ pub fn register_extension(input: TokenStream) -> TokenStream { } } }); + let vtab_calls = vtabs.iter().map(|vtab_ident| { + let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span()); + quote! { + { + let result = unsafe{ #vtab_ident::#register_fn(api)}; + if result == ::limbo_ext::ResultCode::OK { + let result = <#vtab_ident as ::limbo_ext::VTabModule>::connect(api); + return result; + } else { + return result; + } + } + } + }); let static_aggregates = aggregate_calls.clone(); let static_scalars = scalar_calls.clone(); + let static_vtabs = vtab_calls.clone(); let expanded = quote! { #[cfg(not(target_family = "wasm"))] @@ -404,20 +517,23 @@ pub fn register_extension(input: TokenStream) -> TokenStream { #(#static_aggregates)* + #(#static_vtabs)* + ::limbo_ext::ResultCode::OK } #[cfg(not(feature = "static"))] - #[no_mangle] - pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { - let api = unsafe { &*api }; - #(#scalar_calls)* + #[no_mangle] + pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode { + let api = unsafe { &*api }; + #(#scalar_calls)* - #(#aggregate_calls)* + #(#aggregate_calls)* - ::limbo_ext::ResultCode::OK - } + #(#vtab_calls)* + + ::limbo_ext::ResultCode::OK + } }; - TokenStream::from(expanded) } From 661c74e338287876b930d717bfd456e8400a1644 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sat, 1 Feb 2025 18:51:27 -0500 Subject: [PATCH 11/16] Apply new planner structure to virtual table impl --- Cargo.lock | 2 ++ core/Cargo.toml | 2 ++ core/ext/mod.rs | 5 ++++ core/util.rs | 5 ++-- core/vdbe/mod.rs | 1 + extensions/core/src/lib.rs | 2 +- extensions/series/Cargo.toml | 8 +++++- extensions/series/src/lib.rs | 53 ++++++++++++++---------------------- macros/src/lib.rs | 24 +++++++++++++--- 9 files changed, 61 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 26cd80646..2c0eb8db1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1616,6 +1616,7 @@ dependencies = [ "limbo_macros", "limbo_percentile", "limbo_regexp", + "limbo_series", "limbo_time", "limbo_uuid", "log", @@ -1707,6 +1708,7 @@ version = "0.0.14" dependencies = [ "limbo_ext", "log", + "mimalloc", ] [[package]] diff --git a/core/Cargo.toml b/core/Cargo.toml index 386bf01c7..687f4ff19 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -27,6 +27,7 @@ percentile = ["limbo_percentile/static"] regexp = ["limbo_regexp/static"] time = ["limbo_time/static"] crypto = ["limbo_crypto/static"] +series = ["limbo_series/static"] [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.6.1", optional = true } @@ -67,6 +68,7 @@ limbo_regexp = { path = "../extensions/regexp", optional = true, features = ["st limbo_percentile = { path = "../extensions/percentile", optional = true, features = ["static"] } limbo_time = { path = "../extensions/time", optional = true, features = ["static"] } limbo_crypto = { path = "../extensions/crypto", optional = true, features = ["static"] } +limbo_series = { path = "../extensions/series", optional = true, features = ["static"] } miette = "7.4.0" strum = "0.26" parking_lot = "0.12.3" diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 6d034e313..c4b6006e3 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -131,6 +131,7 @@ impl Database { return ResultCode::Error; }; let vtab_module = self.vtab_modules.get(name).unwrap().clone(); + let vtab = VirtualTable { name: name.to_string(), implementation: vtab_module, @@ -172,6 +173,10 @@ impl Database { if unsafe { !limbo_crypto::register_extension_static(&ext_api).is_ok() } { return Err("Failed to register crypto extension".to_string()); } + #[cfg(feature = "series")] + if unsafe { !limbo_series::register_extension_static(&ext_api).is_ok() } { + return Err("Failed to register series extension".to_string()); + } Ok(()) } } diff --git a/core/util.rs b/core/util.rs index 5251b36cf..654951700 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1,7 +1,6 @@ +use sqlite3_parser::ast::{self, CreateTableBody, Expr, FunctionTail, Literal}; use std::{rc::Rc, sync::Arc}; -use sqlite3_parser::ast::{CreateTableBody, Expr, FunctionTail, Literal}; - use crate::{ schema::{self, Column, Schema, Type}, Result, Statement, StepResult, IO, @@ -308,7 +307,7 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } -pub fn columns_from_create_table_body(body: CreateTableBody) -> Result, ()> { +pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result, ()> { let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else { return Err(()); }; diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index c25dc572b..5003b72c6 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -280,6 +280,7 @@ fn get_cursor_as_virtual_mut<'long, 'short>( .as_virtual_mut(); cursor } + struct Bitfield([u64; N]); impl Bitfield { diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index fec363c44..0e550fca1 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -99,8 +99,8 @@ pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool; pub trait VTabModule: 'static { type VCursor: VTabCursor; + const NAME: &'static str; - fn name() -> &'static str; fn connect(api: &ExtensionApi) -> ResultCode; fn open() -> Self::VCursor; fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode; diff --git a/extensions/series/Cargo.toml b/extensions/series/Cargo.toml index 73a634ac7..cca322294 100644 --- a/extensions/series/Cargo.toml +++ b/extensions/series/Cargo.toml @@ -6,10 +6,16 @@ edition.workspace = true license.workspace = true repository.workspace = true +[features] +static = ["limbo_ext/static"] + [lib] crate-type = ["cdylib", "lib"] [dependencies] -limbo_ext = { path = "../core"} +limbo_ext = { path = "../core", features = ["static"] } log = "0.4.20" + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +mimalloc = { version = "*", default-features = false } diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index f438c6fce..63b6c6227 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,21 +1,27 @@ use limbo_ext::{ register_extension, ExtensionApi, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value, - ValueType, }; register_extension! { vtabs: { GenerateSeriesVTab } } +macro_rules! try_option { + ($expr:expr, $err:expr) => { + match $expr { + Some(val) => val, + None => return $err, + } + }; +} + /// A virtual table that generates a sequence of integers #[derive(Debug, VTabModuleDerive)] struct GenerateSeriesVTab; impl VTabModule for GenerateSeriesVTab { type VCursor = GenerateSeriesCursor; - fn name() -> &'static str { - "generate_series" - } + const NAME: &'static str = "generate_series"; fn connect(api: &ExtensionApi) -> ResultCode { // Create table schema @@ -25,8 +31,7 @@ impl VTabModule for GenerateSeriesVTab { stop INTEGER HIDDEN, step INTEGER HIDDEN )"; - let name = Self::name(); - api.declare_virtual_table(name, sql) + api.declare_virtual_table(Self::NAME, sql) } fn open() -> Self::VCursor { @@ -43,35 +48,19 @@ impl VTabModule for GenerateSeriesVTab { if arg_count == 0 || arg_count > 3 { return ResultCode::InvalidArgs; } - let start = { - if args[0].value_type() == ValueType::Integer { - args[0].to_integer().unwrap() - } else { - return ResultCode::InvalidArgs; - } - }; - let stop = if args.len() == 1 { - i64::MAX - } else { - if args[1].value_type() == ValueType::Integer { - args[1].to_integer().unwrap() - } else { - return ResultCode::InvalidArgs; - } - }; - let step = if args.len() <= 2 { - 1 - } else { - if args[2].value_type() == ValueType::Integer { - args[2].to_integer().unwrap() - } else { - return ResultCode::InvalidArgs; - } - }; + let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); + let stop = try_option!( + args.get(1).map(|v| v.to_integer().unwrap_or(i64::MAX)), + ResultCode::InvalidArgs + ); + let step = try_option!( + args.get(2).map(|v| v.to_integer().unwrap_or(1)), + ResultCode::InvalidArgs + ); cursor.start = start; cursor.current = start; - cursor.stop = stop; cursor.step = step; + cursor.stop = stop; ResultCode::OK } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 6b0df9679..8dee8dc66 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -343,6 +343,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { unsafe extern "C" fn #connect_fn_name( db: *const ::std::ffi::c_void, ) -> ::limbo_ext::ResultCode { + if db.is_null() { + return ::limbo_ext::ResultCode::Error; + } let api = unsafe { &*(db as *const ExtensionApi) }; <#struct_name as ::limbo_ext::VTabModule>::connect(api) } @@ -360,6 +363,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { argc: i32, argv: *const ::limbo_ext::Value, ) -> ::limbo_ext::ResultCode { + if cursor.is_null() { + return ::limbo_ext::ResultCode::Error; + } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; let args = std::slice::from_raw_parts(argv, argc as usize); <#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args) @@ -370,6 +376,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { cursor: *mut ::std::ffi::c_void, idx: u32, ) -> ::limbo_ext::Value { + if cursor.is_null() { + return ::limbo_ext::Value::error(ResultCode::Error); + } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) } @@ -378,6 +387,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { unsafe extern "C" fn #next_fn_name( cursor: *mut ::std::ffi::c_void, ) -> ::limbo_ext::ResultCode { + if cursor.is_null() { + return ::limbo_ext::ResultCode::Error; + } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; <#struct_name as ::limbo_ext::VTabModule>::next(cursor) } @@ -386,6 +398,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { unsafe extern "C" fn #eof_fn_name( cursor: *mut ::std::ffi::c_void, ) -> bool { + if cursor.is_null() { + return true; + } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; <#struct_name as ::limbo_ext::VTabModule>::eof(cursor) } @@ -399,7 +414,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { } let api = &*api; - let name = <#struct_name as ::limbo_ext::VTabModule>::name(); + let name = <#struct_name as ::limbo_ext::VTabModule>::NAME; // name needs to be a c str FFI compatible, NOT CString let name_c = std::ffi::CString::new(name).unwrap(); @@ -493,9 +508,9 @@ pub fn register_extension(input: TokenStream) -> TokenStream { let result = unsafe{ #vtab_ident::#register_fn(api)}; if result == ::limbo_ext::ResultCode::OK { let result = <#vtab_ident as ::limbo_ext::VTabModule>::connect(api); - return result; - } else { - return result; + if !result.is_ok() { + return result; + } } } } @@ -535,5 +550,6 @@ pub fn register_extension(input: TokenStream) -> TokenStream { ::limbo_ext::ResultCode::OK } }; + TokenStream::from(expanded) } From d4c06545e14f723e34f2bacdc7b261eb5ceea93f Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 2 Feb 2025 20:18:36 -0500 Subject: [PATCH 12/16] Refactor vtable impl and remove Rc Refcell from module --- core/lib.rs | 4 +- core/schema.rs | 20 +++------ core/translate/expr.rs | 4 +- core/translate/optimizer.rs | 4 +- core/translate/planner.rs | 2 - core/translate/select.rs | 6 +-- extensions/core/src/lib.rs | 5 +++ extensions/series/src/lib.rs | 13 ++++++ macros/src/lib.rs | 83 ++++++++++++++++++++++++++++++++++++ 9 files changed, 118 insertions(+), 23 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index ccc2c2273..f381a9fce 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -363,7 +363,7 @@ impl Connection { pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow(); + let syms: &SymbolTable = &db.syms.borrow_mut(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; if let Some(cmd) = cmd { @@ -417,7 +417,7 @@ impl Connection { #[cfg(not(target_family = "wasm"))] pub fn load_extension>(&self, path: P) -> Result<()> { - Database::load_extension(self.db.as_ref(), path) + Database::load_extension(&self.db, path) } /// Close a connection and checkpoint. diff --git a/core/schema.rs b/core/schema.rs index e7688b58e..f4a6aee2b 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -68,20 +68,11 @@ impl Table { } } - pub fn get_column_at(&self, index: usize) -> &Column { + pub fn get_column_at(&self, index: usize) -> Option<&Column> { match self { - Self::BTree(table) => table - .columns - .get(index) - .expect("column index out of bounds"), - Self::Pseudo(table) => table - .columns - .get(index) - .expect("column index out of bounds"), - Self::Virtual(table) => table - .columns - .get(index) - .expect("column index out of bounds"), + Self::BTree(table) => table.columns.get(index), + Self::Pseudo(table) => table.columns.get(index), + Self::Virtual(table) => table.columns.get(index), } } @@ -100,6 +91,7 @@ impl Table { Self::Virtual(_) => None, } } + pub fn virtual_table(&self) -> Option> { match self { Self::Virtual(table) => Some(table.clone()), @@ -172,7 +164,7 @@ impl BTreeTable { sql.push_str(",\n"); } sql.push_str(" "); - sql.push_str(&column.name.as_ref().expect("column name is None")); + sql.push_str(column.name.as_ref().expect("column name is None")); sql.push(' '); sql.push_str(&column.ty.to_string()); } diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 8b2e70185..bef18c9f2 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1839,7 +1839,9 @@ pub fn translate_expr( dest: target_register, }); } - let column = table_reference.table.get_column_at(*column); + let Some(column) = table_reference.table.get_column_at(*column) else { + crate::bail_parse_error!("column index out of bounds"); + }; maybe_apply_affinity(column.ty, target_register, program); Ok(target_register) } diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 73124060f..99de57398 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -307,7 +307,9 @@ impl Optimizable for ast::Expr { else { return Ok(None); }; - let column = table_reference.table.get_column_at(*column); + let Some(column) = table_reference.table.get_column_at(*column) else { + return Ok(None); + }; for index in available_indexes_for_table.iter() { if let Some(name) = column.name.as_ref() { if &index.columns.first().unwrap().name == name { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 272b788a2..dcde7dc62 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use super::{ plan::{ Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, diff --git a/core/translate/select.rs b/core/translate/select.rs index 2940cfca6..2a055afd2 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -28,9 +28,9 @@ pub fn translate_select( let mut program = ProgramBuilder::new(ProgramBuilderOpts { query_mode, - num_cursors: count_plan_required_cursors(&select), - approx_num_insns: estimate_num_instructions(&select), - approx_num_labels: estimate_num_labels(&select), + num_cursors: count_plan_required_cursors(select), + approx_num_insns: estimate_num_instructions(select), + approx_num_labels: estimate_num_labels(select), }); emit_program(&mut program, select_plan, syms)?; Ok(program) diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 0e550fca1..30ddece57 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -110,10 +110,15 @@ pub trait VTabModule: 'static { } pub trait VTabCursor: Sized { + type Error; fn rowid(&self) -> i64; fn column(&self, idx: u32) -> Value; fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; + fn set_error(&mut self, error: Self::Error); + fn error(&self) -> Option { + None + } } #[repr(C)] diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 63b6c6227..ef278d451 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -40,12 +40,14 @@ impl VTabModule for GenerateSeriesVTab { stop: 0, step: 0, current: 0, + error: None, } } fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { // args are the start, stop, and step if arg_count == 0 || arg_count > 3 { + cursor.set_error("Expected between 1 and 3 arguments"); return ResultCode::InvalidArgs; } let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); @@ -84,6 +86,7 @@ struct GenerateSeriesCursor { stop: i64, step: i64, current: i64, + error: Option<&'static str>, } impl GenerateSeriesCursor { @@ -101,6 +104,8 @@ impl GenerateSeriesCursor { } impl VTabCursor for GenerateSeriesCursor { + type Error = &'static str; + fn next(&mut self) -> ResultCode { GenerateSeriesCursor::next(self) } @@ -119,6 +124,14 @@ impl VTabCursor for GenerateSeriesCursor { } } + fn error(&self) -> Option { + self.error + } + + fn set_error(&mut self, err: &'static str) { + self.error = Some(err); + } + fn rowid(&self) -> i64 { ((self.current - self.start) / self.step) + 1 } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8dee8dc66..56d019525 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -324,6 +324,89 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +/// Macro to derive a VTabModule for your extension. This macro will generate +/// the necessary functions to register your module with core. You must implement +/// the VTabModule trait for your struct, and the VTabCursor trait for your cursor. +/// ```ignore +///#[derive(Debug, VTabModuleDerive)] +///struct CsvVTab; +///impl VTabModule for CsvVTab { +/// type VCursor = CsvCursor; +/// const NAME: &'static str = "csv_data"; +/// +/// /// Declare the schema for your virtual table +/// fn connect(api: &ExtensionApi) -> ResultCode { +/// let sql = "CREATE TABLE csv_data( +/// name TEXT, +/// age TEXT, +/// city TEXT +/// )"; +/// api.declare_virtual_table(Self::NAME, sql) +/// } +/// /// Open the virtual table and return a cursor +/// fn open() -> Self::VCursor { +/// let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); +/// let rows: Vec> = csv_content +/// .lines() +/// .skip(1) +/// .map(|line| { +/// line.split(',') +/// .map(|s| s.trim().to_string()) +/// .collect() +/// }) +/// .collect(); +/// CsvCursor { rows, index: 0 } +/// } +/// /// Filter the virtual table based on arguments (omitted here for simplicity) +/// fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { +/// ResultCode::OK +/// } +/// /// Return the value for a given column index +/// fn column(cursor: &Self::VCursor, idx: u32) -> Value { +/// cursor.column(idx) +/// } +/// /// Move the cursor to the next row +/// fn next(cursor: &mut Self::VCursor) -> ResultCode { +/// if cursor.index < cursor.rows.len() - 1 { +/// cursor.index += 1; +/// ResultCode::OK +/// } else { +/// ResultCode::EOF +/// } +/// } +/// fn eof(cursor: &Self::VCursor) -> bool { +/// cursor.index >= cursor.rows.len() +/// } +/// #[derive(Debug)] +/// struct CsvCursor { +/// rows: Vec>, +/// index: usize, +/// +/// impl CsvCursor { +/// /// Returns the value for a given column index. +/// fn column(&self, idx: u32) -> Value { +/// let row = &self.rows[self.index]; +/// if (idx as usize) < row.len() { +/// Value::from_text(&row[idx as usize]) +/// } else { +/// Value::null() +/// } +/// } +/// // Implement the VTabCursor trait for your virtual cursor +/// impl VTabCursor for CsvCursor { +/// fn next(&mut self) -> ResultCode { +/// Self::next(self) +/// } +/// fn eof(&self) -> bool { +/// self.index >= self.rows.len() +/// } +/// fn column(&self, idx: u32) -> Value { +/// self.column(idx) +/// } +/// fn rowid(&self) -> i64 { +/// self.index as i64 +/// } + #[proc_macro_derive(VTabModuleDerive)] pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); From ad30ccdc0e43a99465a32e6997ccd811480b0cbe Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 2 Feb 2025 20:20:17 -0500 Subject: [PATCH 13/16] Add docs in extension README for vtable modules --- core/lib.rs | 4 +- extensions/core/README.md | 119 +++++++++++++++++++++++++++++++++-- extensions/series/src/lib.rs | 7 +-- macros/src/lib.rs | 2 +- 4 files changed, 119 insertions(+), 13 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index f381a9fce..0c10003ac 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -26,8 +26,8 @@ use fallible_iterator::FallibleIterator; #[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; #[cfg(not(target_family = "wasm"))] -use limbo_ext::{ExtensionApi, ExtensionEntryPoint, ResultCode}; -use limbo_ext::{VTabModuleImpl, Value as ExtValue}; +use limbo_ext::{ExtensionApi, ExtensionEntryPoint}; +use limbo_ext::{ResultCode, VTabModuleImpl, Value as ExtValue}; use log::trace; use parking_lot::RwLock; use schema::{Column, Schema}; diff --git a/extensions/core/README.md b/extensions/core/README.md index bcb7ff86f..6dd187122 100644 --- a/extensions/core/README.md +++ b/extensions/core/README.md @@ -9,7 +9,8 @@ like traditional `sqlite3` extensions, but are able to be written in much more e - [ x ] **Scalar Functions**: Create scalar functions using the `scalar` macro. - [ x ] **Aggregate Functions**: Define aggregate functions with `AggregateDerive` macro and `AggFunc` trait. - - [] **Virtual tables**: TODO + - [ x ] **Virtual tables**: Create a module for a virtual table with the `VTabModuleDerive` macro and `VTabCursor` trait. + - [] **VFS Modules** --- ## Installation @@ -17,24 +18,32 @@ like traditional `sqlite3` extensions, but are able to be written in much more e Add the crate to your `Cargo.toml`: ```toml + +[features] +static = ["limbo_ext/static"] + [dependencies] -limbo_ext = { path = "path/to/limbo/extensions/core" } # temporary until crate is published +limbo_ext = { path = "path/to/limbo/extensions/core", features = ["static"] } # temporary until crate is published + # mimalloc is required if you intend on linking dynamically. It is imported for you by the register_extension # macro, so no configuration is needed. But it must be added to your Cargo.toml [target.'cfg(not(target_family = "wasm"))'.dependencies] mimalloc = { version = "*", default-features = false } -``` -**NOTE** Crate must be of type `cdylib` if you wish to link dynamically -``` +# NOTE: Crate must be of type `cdylib` if you wish to link dynamically [lib] crate-type = ["cdylib", "lib"] ``` -`cargo build` will output a shared library that can be loaded with `.load target/debug/libyour_crate_name` +`cargo build` will output a shared library that can be loaded by the following options: +#### **CLI:** + `.load target/debug/libyour_crate_name` + +#### **SQL:** + `SELECT load_extension('target/debug/libyour_crate_name')` Extensions can be registered with the `register_extension!` macro: @@ -44,6 +53,7 @@ Extensions can be registered with the `register_extension!` macro: register_extension!{ scalars: { double }, // name of your function, if different from attribute name aggregates: { Percentile }, + vtabs: { CsvVTable }, } ``` @@ -140,4 +150,101 @@ impl AggFunc for Percentile { } ``` +### Virtual Table Example: +```rust + +/// Example: A virtual table that operates on a CSV file as a database table. +/// This example assumes that the CSV file is located at "data.csv" in the current directory. +#[derive(Debug, VTabModuleDerive)] +struct CsvVTable; + +impl VTabModule for CsvVTable { + type VCursor = CsvCursor; + /// Declare the name for your virtual table + const NAME: &'static str = "csv_data"; + + /// Declare the table schema and call `api.declare_virtual_table` with the schema sql. + fn connect(api: &ExtensionApi) -> ResultCode { + let sql = "CREATE TABLE csv_data( + name TEXT, + age TEXT, + city TEXT + )"; + api.declare_virtual_table(Self::NAME, sql) + } + + /// Open to return a new cursor: In this simple example, the CSV file is read completely into memory on connect. + fn open() -> Self::VCursor { + // Read CSV file contents from "data.csv" + let csv_content = fs::read_to_string("data.csv").unwrap_or_default(); + // For simplicity, we'll ignore the header row. + let rows: Vec> = csv_content + .lines() + .skip(1) + .map(|line| { + line.split(',') + .map(|s| s.trim().to_string()) + .collect() + }) + .collect(); + CsvCursor { rows, index: 0 } + } + + /// Filter through result columns. (not used in this simple example) + fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode { + ResultCode::OK + } + + /// Return the value for the column at the given index in the current row. + fn column(cursor: &Self::VCursor, idx: u32) -> Value { + cursor.column(idx) + } + + /// Next advances the cursor to the next row. + fn next(cursor: &mut Self::VCursor) -> ResultCode { + if cursor.index < cursor.rows.len() - 1 { + cursor.index += 1; + ResultCode::OK + } else { + ResultCode::EOF + } + } + + /// Return true if the cursor is at the end. + fn eof(cursor: &Self::VCursor) -> bool { + cursor.index >= cursor.rows.len() + } +} + +/// The cursor for iterating over CSV rows. +#[derive(Debug)] +struct CsvCursor { + rows: Vec>, + index: usize, +} + +/// Implement the VTabCursor trait for your cursor type +impl VTabCursor for CsvCursor { + fn next(&mut self) -> ResultCode { + CsvCursor::next(self) + } + + fn eof(&self) -> bool { + self.index >= self.rows.len() + } + + fn column(&self, idx: u32) -> Value { + let row = &self.rows[self.index]; + if (idx as usize) < row.len() { + Value::from_text(&row[idx as usize]) + } else { + Value::null() + } + } + + fn rowid(&self) -> i64 { + self.index as i64 + } +} +``` diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index ef278d451..fdd84c3b2 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -47,7 +47,6 @@ impl VTabModule for GenerateSeriesVTab { fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode { // args are the start, stop, and step if arg_count == 0 || arg_count > 3 { - cursor.set_error("Expected between 1 and 3 arguments"); return ResultCode::InvalidArgs; } let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs); @@ -86,7 +85,7 @@ struct GenerateSeriesCursor { stop: i64, step: i64, current: i64, - error: Option<&'static str>, + error: Option, } impl GenerateSeriesCursor { @@ -104,7 +103,7 @@ impl GenerateSeriesCursor { } impl VTabCursor for GenerateSeriesCursor { - type Error = &'static str; + type Error = ResultCode; fn next(&mut self) -> ResultCode { GenerateSeriesCursor::next(self) @@ -128,7 +127,7 @@ impl VTabCursor for GenerateSeriesCursor { self.error } - fn set_error(&mut self, err: &'static str) { + fn set_error(&mut self, err: ResultCode) { self.error = Some(err); } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 56d019525..632b95615 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -406,7 +406,7 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream { /// fn rowid(&self) -> i64 { /// self.index as i64 /// } - +/// #[proc_macro_derive(VTabModuleDerive)] pub fn derive_vtab_module(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); From a8ae95716268162e313762599a5f35c90194a779 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 5 Feb 2025 13:35:24 -0500 Subject: [PATCH 14/16] Add tests for series extension, finish initial vtable impl --- core/lib.rs | 6 +- extensions/core/src/lib.rs | 2 +- extensions/series/src/lib.rs | 21 ++---- macros/src/lib.rs | 3 + testing/extensions.py | 134 ++++++++++++++++++++++------------- 5 files changed, 98 insertions(+), 68 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 0c10003ac..41e8a6cd6 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -272,8 +272,8 @@ impl Connection { let sql = sql.as_ref(); trace!("Preparing: {}", sql); let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); + let syms = &db.syms.borrow(); let cmd = parser.next()?; if let Some(cmd) = cmd { match cmd { @@ -363,7 +363,7 @@ impl Connection { pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); let db = &self.db; - let syms: &SymbolTable = &db.syms.borrow_mut(); + let syms: &SymbolTable = &db.syms.borrow(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; if let Some(cmd) = cmd { @@ -551,7 +551,7 @@ impl VirtualTable { }; match rc { ResultCode::OK => Ok(()), - _ => Err(LimboError::ExtensionError("Filter failed".to_string())), + _ => Err(LimboError::ExtensionError(rc.to_string())), } } diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 30ddece57..805051079 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -110,7 +110,7 @@ pub trait VTabModule: 'static { } pub trait VTabCursor: Sized { - type Error; + type Error: std::fmt::Display; fn rowid(&self) -> i64; fn column(&self, idx: u32) -> Value; fn eof(&self) -> bool; diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index fdd84c3b2..9732c909d 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -88,25 +88,16 @@ struct GenerateSeriesCursor { error: Option, } -impl GenerateSeriesCursor { - fn next(&mut self) -> ResultCode { - let current = self.current; - - // Check if we've reached the end - if (self.step > 0 && current >= self.stop) || (self.step < 0 && current <= self.stop) { - return ResultCode::EOF; - } - - self.current = current.saturating_add(self.step); - ResultCode::OK - } -} - impl VTabCursor for GenerateSeriesCursor { type Error = ResultCode; fn next(&mut self) -> ResultCode { - GenerateSeriesCursor::next(self) + let next_val = self.current.saturating_add(self.step); + if (self.step > 0 && next_val > self.stop) || (self.step < 0 && next_val < self.stop) { + return ResultCode::EOF; + } + self.current = next_val; + ResultCode::OK } fn eof(&self) -> bool { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 632b95615..2c88776f7 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -463,6 +463,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { return ::limbo_ext::Value::error(ResultCode::Error); } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; + if let Some(err) = <#struct_name as ::limbo_ext::VTabModule>::VCursor::error(cursor) { + return ::limbo_ext::Value::error_with_message(err.to_string()); + } <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) } diff --git a/testing/extensions.py b/testing/extensions.py index d4a0a69c0..f2099d101 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -110,14 +110,19 @@ def validate_blob(result): # and assert they are valid hex digits return int(result, 16) is not None + def validate_string_uuid(result): return len(result) == 36 and result.count("-") == 4 -def returns_error(result): +def returns_error_no_func(result): return "error: no such function: " in result +def returns_vtable_parse_err(result): + return "Parse error: Virtual table" in result + + def returns_null(result): return result == "" or result == "\n" @@ -129,6 +134,7 @@ def assert_now_unixtime(result): def assert_specific_time(result): return result == "1736720789" + def test_uuid(pipe): specific_time = "01945ca0-3189-76c0-9a8f-caf310fc8b8e" # these are built into the binary, so we just test they work @@ -165,7 +171,7 @@ def test_regexp(pipe): extension_path = "./target/debug/liblimbo_regexp.so" # before extension loads, assert no function - run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error) + run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error_no_func) run_test(pipe, f".load {extension_path}", returns_null) print(f"Extension {extension_path} loaded successfully.") run_test(pipe, "SELECT regexp('a.c', 'abc');", validate_true) @@ -205,13 +211,14 @@ def validate_percentile2(res): def validate_percentile_disc(res): return res == "40.0" + def test_aggregates(pipe): extension_path = "./target/debug/liblimbo_percentile.so" # assert no function before extension loads run_test( pipe, "SELECT median(1);", - returns_error, + returns_error_no_func, "median agg function returns null when ext not loaded", ) run_test( @@ -252,63 +259,55 @@ def test_aggregates(pipe): pipe, "SELECT percentile_disc(value, 0.55) from test;", validate_percentile_disc ) -# Hashes -def validate_blake3(a): - return a == "6437b3ac38465133ffb63b75273a8db548c558465d79db03fd359c6cd5bd9d85" - -def validate_md5(a): - return a == "900150983cd24fb0d6963f7d28e17f72" - -def validate_sha1(a): - return a == "a9993e364706816aba3e25717850c26c9cd0d89d" - -def validate_sha256(a): - return a == "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" - -def validate_sha384(a): - return a == "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7" - -def validate_sha512(a): - return a == "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f" # Encoders and decoders def validate_url_encode(a): - return a == f"%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29" + return a == "%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29" + def validate_url_decode(a): return a == "/hello?text=(ಠ_ಠ)" + def validate_hex_encode(a): return a == "68656c6c6f" + def validate_hex_decode(a): return a == "hello" + def validate_base85_encode(a): return a == "BOu!rDZ" + def validate_base85_decode(a): return a == "hello" + def validate_base32_encode(a): return a == "NBSWY3DP" + def validate_base32_decode(a): return a == "hello" + def validate_base64_encode(a): return a == "aGVsbG8=" + def validate_base64_decode(a): return a == "hello" + def test_crypto(pipe): extension_path = "./target/debug/liblimbo_crypto.so" # assert no function before extension loads run_test( pipe, "SELECT crypto_blake('a');", - returns_error, + lambda res: "Error" in res, "crypto_blake3 returns null when ext not loaded", ) run_test( @@ -321,104 +320,139 @@ def test_crypto(pipe): run_test( pipe, "SELECT crypto_encode(crypto_blake3('abc'), 'hex');", - validate_blake3, - "blake3 should encrypt correctly" + lambda res: res + == "6437b3ac38465133ffb63b75273a8db548c558465d79db03fd359c6cd5bd9d85", + "blake3 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_md5('abc'), 'hex');", - validate_md5, - "md5 should encrypt correctly" + lambda res: res == "900150983cd24fb0d6963f7d28e17f72", + "md5 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_sha1('abc'), 'hex');", - validate_sha1, - "sha1 should encrypt correctly" + lambda res: res == "a9993e364706816aba3e25717850c26c9cd0d89d", + "sha1 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_sha256('abc'), 'hex');", - validate_sha256, - "sha256 should encrypt correctly" + lambda a: a + == "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad", + "sha256 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_sha384('abc'), 'hex');", - validate_sha384, - "sha384 should encrypt correctly" + lambda a: a + == "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7", + "sha384 should encrypt correctly", ) run_test( pipe, "SELECT crypto_encode(crypto_sha512('abc'), 'hex');", - validate_sha512, - "sha512 should encrypt correctly" - ) + lambda a: a + == "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f", + "sha512 should encrypt correctly", + ) # Encoding and Decoding run_test( pipe, "SELECT crypto_encode('hello', 'base32');", validate_base32_encode, - "base32 should encode correctly" - ) + "base32 should encode correctly", + ) run_test( pipe, "SELECT crypto_decode('NBSWY3DP', 'base32');", validate_base32_decode, - "base32 should decode correctly" + "base32 should decode correctly", ) run_test( pipe, "SELECT crypto_encode('hello', 'base64');", validate_base64_encode, - "base64 should encode correctly" + "base64 should encode correctly", ) run_test( pipe, "SELECT crypto_decode('aGVsbG8=', 'base64');", validate_base64_decode, - "base64 should decode correctly" + "base64 should decode correctly", ) run_test( pipe, "SELECT crypto_encode('hello', 'base85');", validate_base85_encode, - "base85 should encode correctly" + "base85 should encode correctly", ) run_test( pipe, "SELECT crypto_decode('BOu!rDZ', 'base85');", validate_base85_decode, - "base85 should decode correctly" + "base85 should decode correctly", ) run_test( pipe, "SELECT crypto_encode('hello', 'hex');", validate_hex_encode, - "hex should encode correctly" + "hex should encode correctly", ) run_test( pipe, "SELECT crypto_decode('68656c6c6f', 'hex');", validate_hex_decode, - "hex should decode correctly" + "hex should decode correctly", ) - + run_test( pipe, "SELECT crypto_encode('/hello?text=(ಠ_ಠ)', 'url');", validate_url_encode, - "url should encode correctly" + "url should encode correctly", ) run_test( pipe, - f"SELECT crypto_decode('%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29', 'url');", + "SELECT crypto_decode('%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29', 'url');", validate_url_decode, - "url should decode correctly" + "url should decode correctly", ) + +def test_series(pipe): + ext_path = "./target/debug/liblimbo_series" + run_test( + pipe, + "SELECT * FROM generate_series(1, 10);", + lambda res: "Virtual table generate_series not found" in res, + ) + run_test(pipe, f".load {ext_path}", returns_null) + run_test( + pipe, + "SELECT * FROM generate_series(1, 10);", + lambda res: "Invalid Argument" in res, + ) + run_test( + pipe, + "SELECT * FROM generate_series(1, 10, 2);", + lambda res: res == "1\n3\n5\n7\n9", + ) + run_test( + pipe, + "SELECT * FROM generate_series(1, 10, 2, 3);", + lambda res: "Invalid Argument" in res, + ) + run_test( + pipe, + "SELECT * FROM generate_series(10, 1, -2);", + lambda res: res == "10\n8\n6\n4\n2", + ) + + def main(): pipe = init_limbo() try: @@ -426,6 +460,8 @@ def main(): test_uuid(pipe) test_aggregates(pipe) test_crypto(pipe) + test_series(pipe) + except Exception as e: print(f"Test FAILED: {e}") pipe.terminate() From cd83ac6146e56d603898f672c9c32efd697efe1e Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 5 Feb 2025 16:14:21 -0500 Subject: [PATCH 15/16] Remove error from vcursor trait in extensions --- extensions/core/src/lib.rs | 4 ---- extensions/series/src/lib.rs | 10 ---------- macros/src/lib.rs | 3 --- 3 files changed, 17 deletions(-) diff --git a/extensions/core/src/lib.rs b/extensions/core/src/lib.rs index 805051079..22d90f572 100644 --- a/extensions/core/src/lib.rs +++ b/extensions/core/src/lib.rs @@ -115,10 +115,6 @@ pub trait VTabCursor: Sized { fn column(&self, idx: u32) -> Value; fn eof(&self) -> bool; fn next(&mut self) -> ResultCode; - fn set_error(&mut self, error: Self::Error); - fn error(&self) -> Option { - None - } } #[repr(C)] diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 9732c909d..83dd334ea 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -40,7 +40,6 @@ impl VTabModule for GenerateSeriesVTab { stop: 0, step: 0, current: 0, - error: None, } } @@ -85,7 +84,6 @@ struct GenerateSeriesCursor { stop: i64, step: i64, current: i64, - error: Option, } impl VTabCursor for GenerateSeriesCursor { @@ -114,14 +112,6 @@ impl VTabCursor for GenerateSeriesCursor { } } - fn error(&self) -> Option { - self.error - } - - fn set_error(&mut self, err: ResultCode) { - self.error = Some(err); - } - fn rowid(&self) -> i64 { ((self.current - self.start) / self.step) + 1 } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 2c88776f7..632b95615 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -463,9 +463,6 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { return ::limbo_ext::Value::error(ResultCode::Error); } let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) }; - if let Some(err) = <#struct_name as ::limbo_ext::VTabModule>::VCursor::error(cursor) { - return ::limbo_ext::Value::error_with_message(err.to_string()); - } <#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx) } From ae88d51e6fa7a41d18ab2e17a31dfbf4c256372c Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Wed, 5 Feb 2025 21:30:55 -0500 Subject: [PATCH 16/16] Remove TableReferenceType enum to clean up planner --- core/ext/mod.rs | 1 + core/lib.rs | 3 +- core/translate/delete.rs | 3 +- core/translate/expr.rs | 67 ++++++++++++++++--------------------- core/translate/main_loop.rs | 43 ++++++++++++------------ core/translate/plan.rs | 32 +++--------------- core/translate/planner.rs | 27 +++++++++------ testing/extensions.py | 2 +- 8 files changed, 76 insertions(+), 102 deletions(-) diff --git a/core/ext/mod.rs b/core/ext/mod.rs index c4b6006e3..67fd78491 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -136,6 +136,7 @@ impl Database { name: name.to_string(), implementation: vtab_module, columns, + args: None, }; self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab); ResultCode::OK diff --git a/core/lib.rs b/core/lib.rs index 41e8a6cd6..8fde24402 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -514,6 +514,7 @@ pub type StepResult = vdbe::StepResult; #[derive(Clone, Debug)] pub struct VirtualTable { name: String, + args: Option>, pub implementation: Rc, columns: Vec, } @@ -537,7 +538,7 @@ impl VirtualTable { OwnedValue::Null => Ok(ExtValue::null()), OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)), OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)), - OwnedValue::Text(t) => Ok(ExtValue::from_text((*t.value).clone())), + OwnedValue::Text(t) => Ok(ExtValue::from_text(t.as_str().to_string())), OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())), other => Err(LimboError::ExtensionError(format!( "Unsupported value type: {:?}", diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 675b58f34..ffad33d73 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -7,7 +7,7 @@ use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, QueryMode}; use crate::{schema::Schema, Result, SymbolTable}; use sqlite3_parser::ast::{Expr, Limit, QualifiedName}; -use super::plan::{TableReference, TableReferenceType}; +use super::plan::TableReference; pub fn translate_delete( query_mode: QueryMode, @@ -48,7 +48,6 @@ pub fn prepare_delete_plan( identifier: table.name.clone(), op: Operation::Scan { iter_dir: None }, join_info: None, - reference_type: TableReferenceType::BTreeTable, }]; let mut where_predicates = vec![]; diff --git a/core/translate/expr.rs b/core/translate/expr.rs index bef18c9f2..c23cb053a 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -3,7 +3,7 @@ use sqlite3_parser::ast::{self, UnaryOperator}; #[cfg(feature = "json")] use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc, VectorFunc}; -use crate::schema::Type; +use crate::schema::{Table, Type}; use crate::util::normalize_ident; use crate::vdbe::{ builder::ProgramBuilder, @@ -13,7 +13,7 @@ use crate::vdbe::{ use crate::Result; use super::emitter::Resolver; -use super::plan::{Operation, TableReference, TableReferenceType}; +use super::plan::{Operation, TableReference}; #[derive(Debug, Clone, Copy)] pub struct ConditionMetadata { @@ -1823,49 +1823,38 @@ pub fn translate_expr( match table_reference.op { // If we are reading a column from a table, we find the cursor that corresponds to // the table and read the column from the cursor. - Operation::Scan { .. } | Operation::Search(_) => { - match &table_reference.reference_type { - TableReferenceType::BTreeTable => { - let cursor_id = program.resolve_cursor_id(&table_reference.identifier); - if *is_rowid_alias { - program.emit_insn(Insn::RowId { - cursor_id, - dest: target_register, - }); - } else { - program.emit_insn(Insn::Column { - cursor_id, - column: *column, - dest: target_register, - }); - } - let Some(column) = table_reference.table.get_column_at(*column) else { - crate::bail_parse_error!("column index out of bounds"); - }; - maybe_apply_affinity(column.ty, target_register, program); - Ok(target_register) - } - TableReferenceType::VirtualTable { .. } => { - let cursor_id = program.resolve_cursor_id(&table_reference.identifier); - program.emit_insn(Insn::VColumn { + Operation::Scan { .. } | Operation::Search(_) => match &table_reference.table { + Table::BTree(_) => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + if *is_rowid_alias { + program.emit_insn(Insn::RowId { + cursor_id, + dest: target_register, + }); + } else { + program.emit_insn(Insn::Column { cursor_id, column: *column, dest: target_register, }); - Ok(target_register) - } - TableReferenceType::Subquery { - result_columns_start_reg, - } => { - program.emit_insn(Insn::Copy { - src_reg: result_columns_start_reg + *column, - dst_reg: target_register, - amount: 0, - }); - Ok(target_register) } + let Some(column) = table_reference.table.get_column_at(*column) else { + crate::bail_parse_error!("column index out of bounds"); + }; + maybe_apply_affinity(column.ty, target_register, program); + Ok(target_register) } - } + Table::Virtual(_) => { + let cursor_id = program.resolve_cursor_id(&table_reference.identifier); + program.emit_insn(Insn::VColumn { + cursor_id, + column: *column, + dest: target_register, + }); + Ok(target_register) + } + _ => unreachable!(), + }, // If we are reading a column from a subquery, we instead copy the column from the // subquery's result registers. Operation::Subquery { diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 4558693e3..35cc505c9 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -1,6 +1,7 @@ use sqlite3_parser::ast; use crate::{ + schema::Table, translate::result_row::emit_select_result, vdbe::{ builder::{CursorType, ProgramBuilder}, @@ -17,7 +18,7 @@ use super::{ order_by::{order_by_sorter_insert, sorter_insert}, plan::{ IterationDirection, Operation, Search, SelectPlan, SelectQueryType, TableReference, - TableReferenceType, WhereTerm, + WhereTerm, }, }; @@ -78,21 +79,18 @@ pub fn init_loop( } match &table.op { Operation::Scan { .. } => { - let ref_type = &table.reference_type; let cursor_id = program.alloc_cursor_id( Some(table.identifier.clone()), - match ref_type { - TableReferenceType::BTreeTable => { - CursorType::BTreeTable(table.btree().unwrap().clone()) - } - TableReferenceType::VirtualTable { .. } => { + match &table.table { + Table::BTree(_) => CursorType::BTreeTable(table.btree().unwrap().clone()), + Table::Virtual(_) => { CursorType::VirtualTable(table.virtual_table().unwrap().clone()) } other => panic!("Invalid table reference type in Scan: {:?}", other), }, ); - match (mode, ref_type) { - (OperationMode::SELECT, TableReferenceType::BTreeTable) => { + match (mode, &table.table) { + (OperationMode::SELECT, Table::BTree(_)) => { let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenReadAsync { cursor_id, @@ -100,7 +98,7 @@ pub fn init_loop( }); program.emit_insn(Insn::OpenReadAwait {}); } - (OperationMode::DELETE, TableReferenceType::BTreeTable) => { + (OperationMode::DELETE, Table::BTree(_)) => { let root_page = table.btree().unwrap().root_page; program.emit_insn(Insn::OpenWriteAsync { cursor_id, @@ -108,7 +106,7 @@ pub fn init_loop( }); program.emit_insn(Insn::OpenWriteAwait {}); } - (OperationMode::SELECT, TableReferenceType::VirtualTable { .. }) => { + (OperationMode::SELECT, Table::Virtual(_)) => { program.emit_insn(Insn::VOpenAsync { cursor_id }); program.emit_insn(Insn::VOpenAwait {}); } @@ -258,10 +256,9 @@ pub fn open_loop( } } Operation::Scan { iter_dir } => { - let ref_type = &table.reference_type; let cursor_id = program.resolve_cursor_id(&table.identifier); - if !matches!(ref_type, TableReferenceType::VirtualTable { .. }) { + if !matches!(&table.table, Table::Virtual(_)) { if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) @@ -271,8 +268,8 @@ pub fn open_loop( program.emit_insn(Insn::RewindAsync { cursor_id }); } } - match ref_type { - TableReferenceType::BTreeTable => program.emit_insn( + match &table.table { + Table::BTree(_) => program.emit_insn( if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) @@ -288,13 +285,18 @@ pub fn open_loop( } }, ), - TableReferenceType::VirtualTable { args, .. } => { + Table::Virtual(ref table) => { + let args = if let Some(args) = table.args.as_ref() { + args + } else { + &vec![] + }; let start_reg = program.alloc_registers(args.len()); let mut cur_reg = start_reg; for arg in args { let reg = cur_reg; cur_reg += 1; - translate_expr(program, Some(tables), arg, reg, &t_ctx.resolver)?; + translate_expr(program, Some(tables), &arg, reg, &t_ctx.resolver)?; } program.emit_insn(Insn::VFilter { cursor_id, @@ -722,11 +724,10 @@ pub fn close_loop( }); } Operation::Scan { iter_dir, .. } => { - let ref_type = &table.reference_type; program.resolve_label(loop_labels.next, program.offset()); let cursor_id = program.resolve_cursor_id(&table.identifier); - match ref_type { - TableReferenceType::BTreeTable { .. } => { + match &table.table { + Table::BTree(_) => { if iter_dir .as_ref() .is_some_and(|dir| *dir == IterationDirection::Backwards) @@ -750,7 +751,7 @@ pub fn close_loop( }); } } - TableReferenceType::VirtualTable { .. } => { + Table::Virtual(_) => { program.emit_insn(Insn::VNext { cursor_id, pc_if_next: loop_labels.loop_start, diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 43cba8e1b..8195aea13 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -198,7 +198,6 @@ pub struct TableReference { pub identifier: String, /// The join info for this table reference, if it is the right side of a join (which all except the first table reference have) pub join_info: Option, - pub reference_type: TableReferenceType, } #[derive(Clone, Debug)] @@ -225,35 +224,17 @@ pub enum Operation { }, } -/// The type of the table reference, either BTreeTable or Subquery -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum TableReferenceType { - /// A BTreeTable is a table that is stored on disk in a B-tree index. - BTreeTable, - /// A subquery. - Subquery { - /// The index of the first register in the query plan that contains the result columns of the subquery. - result_columns_start_reg: usize, - }, - /// A virtual table. - VirtualTable { - /// Arguments to pass e.g. generate_series(1, 10, 2) - args: Vec, - }, -} - impl TableReference { /// Returns the btree table for this table reference, if it is a BTreeTable. pub fn btree(&self) -> Option> { - match &self.reference_type { - TableReferenceType::BTreeTable => self.table.btree(), - TableReferenceType::Subquery { .. } => None, - TableReferenceType::VirtualTable { .. } => None, + match &self.table { + Table::BTree(_) => self.table.btree(), + _ => None, } } pub fn virtual_table(&self) -> Option> { - match &self.reference_type { - TableReferenceType::VirtualTable { .. } => self.table.virtual_table(), + match &self.table { + Table::Virtual(_) => self.table.virtual_table(), _ => None, } } @@ -280,9 +261,6 @@ impl TableReference { result_columns_start_reg: 0, // Will be set in the bytecode emission phase }, table, - reference_type: TableReferenceType::Subquery { - result_columns_start_reg: 0, // Will be set in the bytecode emission phase - }, identifier: identifier.clone(), join_info, } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index dcde7dc62..311458f9f 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,7 +1,7 @@ use super::{ plan::{ Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference, - TableReferenceType, WhereTerm, + WhereTerm, }, select::prepare_select_plan, SymbolTable, @@ -11,7 +11,7 @@ use crate::{ schema::{Schema, Table}, util::{exprs_are_equivalent, normalize_ident}, vdbe::BranchOffset, - Result, + Result, VirtualTable, }; use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit, UnaryOperator}; @@ -301,7 +301,6 @@ fn parse_from_clause_table( table: Table::BTree(table.clone()), identifier: alias.unwrap_or(normalized_qualified_name), join_info: None, - reference_type: TableReferenceType::BTreeTable, }) } ast::SelectTable::Select(subselect, maybe_alias) => { @@ -320,9 +319,9 @@ fn parse_from_clause_table( .unwrap_or(format!("subquery_{}", cur_table_index)); Ok(TableReference::new_subquery(identifier, subplan, None)) } - ast::SelectTable::TableCall(qualified_name, mut maybe_args, maybe_alias) => { - let normalized_name = normalize_ident(qualified_name.name.0.as_str()); - let Some(vtab) = syms.vtabs.get(&normalized_name) else { + ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => { + let normalized_name = &normalize_ident(qualified_name.name.0.as_str()); + let Some(vtab) = syms.vtabs.get(normalized_name) else { crate::bail_parse_error!("Virtual table {} not found", normalized_name); }; let alias = maybe_alias @@ -331,16 +330,22 @@ fn parse_from_clause_table( ast::As::As(id) => id.0.clone(), ast::As::Elided(id) => id.0.clone(), }) - .unwrap_or(normalized_name); + .unwrap_or(normalized_name.to_string()); Ok(TableReference { op: Operation::Scan { iter_dir: None }, join_info: None, - table: Table::Virtual(vtab.clone().into()), + table: Table::Virtual( + VirtualTable { + name: normalized_name.clone(), + args: maybe_args, + implementation: vtab.implementation.clone(), + columns: vtab.columns.clone(), + } + .into(), + ) + .into(), identifier: alias.clone(), - reference_type: TableReferenceType::VirtualTable { - args: maybe_args.take().unwrap_or_default(), - }, }) } _ => todo!(), diff --git a/testing/extensions.py b/testing/extensions.py index f2099d101..cda953f86 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -307,7 +307,7 @@ def test_crypto(pipe): run_test( pipe, "SELECT crypto_blake('a');", - lambda res: "Error" in res, + lambda res: "Parse error" in res, "crypto_blake3 returns null when ext not loaded", ) run_test(