Merge 'Complete initial pass on virtual tables' from Preston Thorpe

EDIT:
Ok this should finally be all set.  There were some massive changes that
had taken place since the original PR 🫠 🫠
merge conflicts were serious
As soon as this is merged, I will finish the rest of what is needed for
`CREATE VIRTUAL TABLE some_table USING some_extension('some argument or
sql');`
I added tests for the series extension.
This PR is already big and covers a LOT of surface area, so we should
ideally try to get this in first, and I'll go back and complete the rest
of the functionality 👍

Closes #858
This commit is contained in:
Jussi Saurio
2025-02-06 16:29:14 +02:00
26 changed files with 1289 additions and 179 deletions

10
Cargo.lock generated
View File

@@ -1616,6 +1616,7 @@ dependencies = [
"limbo_macros",
"limbo_percentile",
"limbo_regexp",
"limbo_series",
"limbo_time",
"limbo_uuid",
"log",
@@ -1701,6 +1702,15 @@ dependencies = [
"regex",
]
[[package]]
name = "limbo_series"
version = "0.0.14"
dependencies = [
"limbo_ext",
"log",
"mimalloc",
]
[[package]]
name = "limbo_sim"
version = "0.0.14"

View File

@@ -20,6 +20,7 @@ members = [
"extensions/percentile",
"extensions/time",
"extensions/crypto",
"extensions/series",
]
exclude = ["perf/latency/limbo"]

View File

@@ -27,6 +27,7 @@ percentile = ["limbo_percentile/static"]
regexp = ["limbo_regexp/static"]
time = ["limbo_time/static"]
crypto = ["limbo_crypto/static"]
series = ["limbo_series/static"]
[target.'cfg(target_os = "linux")'.dependencies]
io-uring = { version = "0.6.1", optional = true }
@@ -67,6 +68,7 @@ limbo_regexp = { path = "../extensions/regexp", optional = true, features = ["st
limbo_percentile = { path = "../extensions/percentile", optional = true, features = ["static"] }
limbo_time = { path = "../extensions/time", optional = true, features = ["static"] }
limbo_crypto = { path = "../extensions/crypto", optional = true, features = ["static"] }
limbo_series = { path = "../extensions/series", optional = true, features = ["static"] }
miette = "7.4.0"
strum = "0.26"
parking_lot = "0.12.3"

View File

@@ -1,6 +1,11 @@
use crate::{function::ExternalFunc, Database};
use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction};
use crate::{function::ExternalFunc, util::columns_from_create_table_body, Database, VirtualTable};
use fallible_iterator::FallibleIterator;
use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabModuleImpl};
pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType};
use sqlite3_parser::{
ast::{Cmd, Stmt},
lexer::sql::Parser,
};
use std::{
ffi::{c_char, c_void, CStr},
rc::Rc,
@@ -44,6 +49,48 @@ unsafe extern "C" fn register_aggregate_function(
db.register_aggregate_function_impl(&name_str, args, (init_func, step_func, finalize_func))
}
unsafe extern "C" fn register_module(
ctx: *mut c_void,
name: *const c_char,
module: VTabModuleImpl,
) -> ResultCode {
let c_str = unsafe { CStr::from_ptr(name) };
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return ResultCode::Error,
};
if ctx.is_null() {
return ResultCode::Error;
}
let db = unsafe { &mut *(ctx as *mut Database) };
db.register_module_impl(&name_str, module)
}
unsafe extern "C" fn declare_vtab(
ctx: *mut c_void,
name: *const c_char,
sql: *const c_char,
) -> ResultCode {
let c_str = unsafe { CStr::from_ptr(name) };
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return ResultCode::Error,
};
let c_str = unsafe { CStr::from_ptr(sql) };
let sql_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return ResultCode::Error,
};
if ctx.is_null() {
return ResultCode::Error;
}
let db = unsafe { &mut *(ctx as *mut Database) };
db.declare_vtab_impl(&name_str, &sql_str)
}
impl Database {
fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode {
self.syms.borrow_mut().functions.insert(
@@ -66,11 +113,42 @@ impl Database {
ResultCode::OK
}
fn register_module_impl(&mut self, name: &str, module: VTabModuleImpl) -> ResultCode {
self.vtab_modules.insert(name.to_string(), Rc::new(module));
ResultCode::OK
}
fn declare_vtab_impl(&mut self, name: &str, sql: &str) -> ResultCode {
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next().unwrap().unwrap();
let Cmd::Stmt(stmt) = cmd else {
return ResultCode::Error;
};
let Stmt::CreateTable { body, .. } = stmt else {
return ResultCode::Error;
};
let Ok(columns) = columns_from_create_table_body(body) else {
return ResultCode::Error;
};
let vtab_module = self.vtab_modules.get(name).unwrap().clone();
let vtab = VirtualTable {
name: name.to_string(),
implementation: vtab_module,
columns,
args: None,
};
self.syms.borrow_mut().vtabs.insert(name.to_string(), vtab);
ResultCode::OK
}
pub fn build_limbo_ext(&self) -> ExtensionApi {
ExtensionApi {
ctx: self as *const _ as *mut c_void,
register_scalar_function,
register_aggregate_function,
register_module,
declare_vtab,
}
}
@@ -96,6 +174,10 @@ impl Database {
if unsafe { !limbo_crypto::register_extension_static(&ext_api).is_ok() } {
return Err("Failed to register crypto extension".to_string());
}
#[cfg(feature = "series")]
if unsafe { !limbo_series::register_extension_static(&ext_api).is_ok() } {
return Err("Failed to register series extension".to_string());
}
Ok(())
}
}

View File

@@ -25,12 +25,13 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use fallible_iterator::FallibleIterator;
#[cfg(not(target_family = "wasm"))]
use libloading::{Library, Symbol};
#[cfg(not(target_family = "wasm"))]
use limbo_ext::{ExtensionApi, ExtensionEntryPoint};
use limbo_ext::{ResultCode, VTabModuleImpl, Value as ExtValue};
use log::trace;
use parking_lot::RwLock;
use schema::Schema;
use sqlite3_parser::ast;
use sqlite3_parser::{ast::Cmd, lexer::sql::Parser};
use schema::{Column, Schema};
use sqlite3_parser::{ast, ast::Cmd, lexer::sql::Parser};
use std::cell::Cell;
use std::collections::HashMap;
use std::num::NonZero;
@@ -44,9 +45,11 @@ use storage::pager::allocate_page;
use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE};
pub use storage::wal::WalFile;
pub use storage::wal::WalFileShared;
use types::OwnedValue;
pub use types::Value;
use util::parse_schema_rows;
use vdbe::builder::QueryMode;
use vdbe::VTabOpaqueCursor;
pub use error::LimboError;
use translate::select::prepare_select_plan;
@@ -82,6 +85,7 @@ pub struct Database {
schema: Rc<RefCell<Schema>>,
header: Rc<RefCell<DatabaseHeader>>,
syms: Rc<RefCell<SymbolTable>>,
vtab_modules: HashMap<String, Rc<VTabModuleImpl>>,
// Shared structures of a Database are the parts that are common to multiple threads that might
// create DB connections.
_shared_page_cache: Arc<RwLock<DumbLruPageCache>>,
@@ -144,6 +148,7 @@ impl Database {
_shared_page_cache: _shared_page_cache.clone(),
_shared_wal: shared_wal.clone(),
syms,
vtab_modules: HashMap::new(),
};
if let Err(e) = db.register_builtins() {
return Err(LimboError::ExtensionError(e));
@@ -267,8 +272,8 @@ impl Connection {
let sql = sql.as_ref();
trace!("Preparing: {}", sql);
let db = &self.db;
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let syms = &db.syms.borrow();
let cmd = parser.next()?;
if let Some(cmd) = cmd {
match cmd {
@@ -412,7 +417,7 @@ impl Connection {
#[cfg(not(target_family = "wasm"))]
pub fn load_extension<P: AsRef<std::ffi::OsStr>>(&self, path: P) -> Result<()> {
Database::load_extension(self.db.as_ref(), path)
Database::load_extension(&self.db, path)
}
/// Close a connection and checkpoint.
@@ -506,10 +511,71 @@ pub type Row = types::Record;
pub type StepResult = vdbe::StepResult;
#[derive(Clone, Debug)]
pub struct VirtualTable {
name: String,
args: Option<Vec<ast::Expr>>,
pub implementation: Rc<VTabModuleImpl>,
columns: Vec<Column>,
}
impl VirtualTable {
pub fn open(&self) -> VTabOpaqueCursor {
let cursor = unsafe { (self.implementation.open)() };
VTabOpaqueCursor::new(cursor)
}
pub fn filter(
&self,
cursor: &VTabOpaqueCursor,
arg_count: usize,
args: Vec<OwnedValue>,
) -> Result<()> {
let mut filter_args = Vec::with_capacity(arg_count);
for i in 0..arg_count {
let ownedvalue_arg = args.get(i).unwrap();
let extvalue_arg: ExtValue = match ownedvalue_arg {
OwnedValue::Null => Ok(ExtValue::null()),
OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)),
OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)),
OwnedValue::Text(t) => Ok(ExtValue::from_text(t.as_str().to_string())),
OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())),
other => Err(LimboError::ExtensionError(format!(
"Unsupported value type: {:?}",
other
))),
}?;
filter_args.push(extvalue_arg);
}
let rc = unsafe {
(self.implementation.filter)(cursor.as_ptr(), arg_count as i32, filter_args.as_ptr())
};
match rc {
ResultCode::OK => Ok(()),
_ => Err(LimboError::ExtensionError(rc.to_string())),
}
}
pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result<OwnedValue> {
let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) };
OwnedValue::from_ffi(&val)
}
pub fn next(&self, cursor: &VTabOpaqueCursor) -> Result<bool> {
let rc = unsafe { (self.implementation.next)(cursor.as_ptr()) };
match rc {
ResultCode::OK => Ok(true),
ResultCode::EOF => Ok(false),
_ => Err(LimboError::ExtensionError("Next failed".to_string())),
}
}
}
pub(crate) struct SymbolTable {
pub functions: HashMap<String, Rc<function::ExternalFunc>>,
#[cfg(not(target_family = "wasm"))]
extensions: Vec<(Library, *const ExtensionApi)>,
pub vtabs: HashMap<String, VirtualTable>,
}
impl std::fmt::Debug for SymbolTable {
@@ -551,6 +617,7 @@ impl SymbolTable {
pub fn new() -> Self {
Self {
functions: HashMap::new(),
vtabs: HashMap::new(),
#[cfg(not(target_family = "wasm"))]
extensions: Vec::new(),
}

View File

@@ -1,3 +1,4 @@
use crate::VirtualTable;
use crate::{util::normalize_ident, Result};
use core::fmt;
use fallible_iterator::FallibleIterator;
@@ -47,6 +48,7 @@ impl Schema {
pub enum Table {
BTree(Rc<BTreeTable>),
Pseudo(Rc<PseudoTable>),
Virtual(Rc<VirtualTable>),
}
impl Table {
@@ -54,6 +56,7 @@ impl Table {
match self {
Table::BTree(table) => table.root_page,
Table::Pseudo(_) => unimplemented!(),
Table::Virtual(_) => unimplemented!(),
}
}
@@ -61,19 +64,15 @@ impl Table {
match self {
Self::BTree(table) => &table.name,
Self::Pseudo(_) => "",
Self::Virtual(table) => &table.name,
}
}
pub fn get_column_at(&self, index: usize) -> &Column {
pub fn get_column_at(&self, index: usize) -> Option<&Column> {
match self {
Self::BTree(table) => table
.columns
.get(index)
.expect("column index out of bounds"),
Self::Pseudo(table) => table
.columns
.get(index)
.expect("column index out of bounds"),
Self::BTree(table) => table.columns.get(index),
Self::Pseudo(table) => table.columns.get(index),
Self::Virtual(table) => table.columns.get(index),
}
}
@@ -81,6 +80,7 @@ impl Table {
match self {
Self::BTree(table) => &table.columns,
Self::Pseudo(table) => &table.columns,
Self::Virtual(table) => &table.columns,
}
}
@@ -88,6 +88,14 @@ impl Table {
match self {
Self::BTree(table) => Some(table.clone()),
Self::Pseudo(_) => None,
Self::Virtual(_) => None,
}
}
pub fn virtual_table(&self) -> Option<Rc<VirtualTable>> {
match self {
Self::Virtual(table) => Some(table.clone()),
_ => None,
}
}
}
@@ -97,6 +105,7 @@ impl PartialEq for Table {
match (self, other) {
(Self::BTree(a), Self::BTree(b)) => Rc::ptr_eq(a, b),
(Self::Pseudo(a), Self::Pseudo(b)) => Rc::ptr_eq(a, b),
(Self::Virtual(a), Self::Virtual(b)) => Rc::ptr_eq(a, b),
_ => false,
}
}
@@ -155,7 +164,7 @@ impl BTreeTable {
sql.push_str(",\n");
}
sql.push_str(" ");
sql.push_str(&column.name.as_ref().expect("column name is None"));
sql.push_str(column.name.as_ref().expect("column name is None"));
sql.push(' ');
sql.push_str(&column.ty.to_string());
}

View File

@@ -3,7 +3,7 @@ use sqlite3_parser::ast::{self, UnaryOperator};
#[cfg(feature = "json")]
use crate::function::JsonFunc;
use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc, VectorFunc};
use crate::schema::Type;
use crate::schema::{Table, Type};
use crate::util::normalize_ident;
use crate::vdbe::{
builder::ProgramBuilder,
@@ -1823,24 +1823,38 @@ pub fn translate_expr(
match table_reference.op {
// If we are reading a column from a table, we find the cursor that corresponds to
// the table and read the column from the cursor.
Operation::Scan { .. } | Operation::Search(_) => {
let cursor_id = program.resolve_cursor_id(&table_reference.identifier);
if *is_rowid_alias {
program.emit_insn(Insn::RowId {
cursor_id,
dest: target_register,
});
} else {
program.emit_insn(Insn::Column {
Operation::Scan { .. } | Operation::Search(_) => match &table_reference.table {
Table::BTree(_) => {
let cursor_id = program.resolve_cursor_id(&table_reference.identifier);
if *is_rowid_alias {
program.emit_insn(Insn::RowId {
cursor_id,
dest: target_register,
});
} else {
program.emit_insn(Insn::Column {
cursor_id,
column: *column,
dest: target_register,
});
}
let Some(column) = table_reference.table.get_column_at(*column) else {
crate::bail_parse_error!("column index out of bounds");
};
maybe_apply_affinity(column.ty, target_register, program);
Ok(target_register)
}
Table::Virtual(_) => {
let cursor_id = program.resolve_cursor_id(&table_reference.identifier);
program.emit_insn(Insn::VColumn {
cursor_id,
column: *column,
dest: target_register,
});
Ok(target_register)
}
let column = table_reference.table.get_column_at(*column);
maybe_apply_affinity(column.ty, target_register, program);
Ok(target_register)
}
_ => unreachable!(),
},
// If we are reading a column from a subquery, we instead copy the column from the
// subquery's result registers.
Operation::Subquery {

View File

@@ -1,6 +1,7 @@
use sqlite3_parser::ast;
use crate::{
schema::Table,
translate::result_row::emit_select_result,
vdbe::{
builder::{CursorType, ProgramBuilder},
@@ -80,25 +81,35 @@ pub fn init_loop(
Operation::Scan { .. } => {
let cursor_id = program.alloc_cursor_id(
Some(table.identifier.clone()),
CursorType::BTreeTable(table.btree().unwrap().clone()),
match &table.table {
Table::BTree(_) => CursorType::BTreeTable(table.btree().unwrap().clone()),
Table::Virtual(_) => {
CursorType::VirtualTable(table.virtual_table().unwrap().clone())
}
other => panic!("Invalid table reference type in Scan: {:?}", other),
},
);
let root_page = table.table.get_root_page();
match mode {
OperationMode::SELECT => {
match (mode, &table.table) {
(OperationMode::SELECT, Table::BTree(_)) => {
let root_page = table.btree().unwrap().root_page;
program.emit_insn(Insn::OpenReadAsync {
cursor_id,
root_page,
});
program.emit_insn(Insn::OpenReadAwait {});
}
OperationMode::DELETE => {
(OperationMode::DELETE, Table::BTree(_)) => {
let root_page = table.btree().unwrap().root_page;
program.emit_insn(Insn::OpenWriteAsync {
cursor_id,
root_page,
});
program.emit_insn(Insn::OpenWriteAwait {});
}
(OperationMode::SELECT, Table::Virtual(_)) => {
program.emit_insn(Insn::VOpenAsync { cursor_id });
program.emit_insn(Insn::VOpenAwait {});
}
_ => {
unimplemented!()
}
@@ -246,30 +257,55 @@ pub fn open_loop(
}
Operation::Scan { iter_dir } => {
let cursor_id = program.resolve_cursor_id(&table.identifier);
if iter_dir
.as_ref()
.is_some_and(|dir| *dir == IterationDirection::Backwards)
{
program.emit_insn(Insn::LastAsync { cursor_id });
} else {
program.emit_insn(Insn::RewindAsync { cursor_id });
}
program.emit_insn(
if !matches!(&table.table, Table::Virtual(_)) {
if iter_dir
.as_ref()
.is_some_and(|dir| *dir == IterationDirection::Backwards)
{
Insn::LastAwait {
cursor_id,
pc_if_empty: loop_end,
}
program.emit_insn(Insn::LastAsync { cursor_id });
} else {
Insn::RewindAwait {
cursor_id,
pc_if_empty: loop_end,
program.emit_insn(Insn::RewindAsync { cursor_id });
}
}
match &table.table {
Table::BTree(_) => program.emit_insn(
if iter_dir
.as_ref()
.is_some_and(|dir| *dir == IterationDirection::Backwards)
{
Insn::LastAwait {
cursor_id,
pc_if_empty: loop_end,
}
} else {
Insn::RewindAwait {
cursor_id,
pc_if_empty: loop_end,
}
},
),
Table::Virtual(ref table) => {
let args = if let Some(args) = table.args.as_ref() {
args
} else {
&vec![]
};
let start_reg = program.alloc_registers(args.len());
let mut cur_reg = start_reg;
for arg in args {
let reg = cur_reg;
cur_reg += 1;
translate_expr(program, Some(tables), &arg, reg, &t_ctx.resolver)?;
}
},
);
program.emit_insn(Insn::VFilter {
cursor_id,
arg_count: args.len(),
args_reg: start_reg,
});
}
other => panic!("Unsupported table reference type: {:?}", other),
}
program.resolve_label(loop_start, program.offset());
for cond in predicates
@@ -690,27 +726,38 @@ pub fn close_loop(
Operation::Scan { iter_dir, .. } => {
program.resolve_label(loop_labels.next, program.offset());
let cursor_id = program.resolve_cursor_id(&table.identifier);
if iter_dir
.as_ref()
.is_some_and(|dir| *dir == IterationDirection::Backwards)
{
program.emit_insn(Insn::PrevAsync { cursor_id });
} else {
program.emit_insn(Insn::NextAsync { cursor_id });
}
if iter_dir
.as_ref()
.is_some_and(|dir| *dir == IterationDirection::Backwards)
{
program.emit_insn(Insn::PrevAwait {
cursor_id,
pc_if_next: loop_labels.loop_start,
});
} else {
program.emit_insn(Insn::NextAwait {
cursor_id,
pc_if_next: loop_labels.loop_start,
});
match &table.table {
Table::BTree(_) => {
if iter_dir
.as_ref()
.is_some_and(|dir| *dir == IterationDirection::Backwards)
{
program.emit_insn(Insn::PrevAsync { cursor_id });
} else {
program.emit_insn(Insn::NextAsync { cursor_id });
}
if iter_dir
.as_ref()
.is_some_and(|dir| *dir == IterationDirection::Backwards)
{
program.emit_insn(Insn::PrevAwait {
cursor_id,
pc_if_next: loop_labels.loop_start,
});
} else {
program.emit_insn(Insn::NextAwait {
cursor_id,
pc_if_next: loop_labels.loop_start,
});
}
}
Table::Virtual(_) => {
program.emit_insn(Insn::VNext {
cursor_id,
pc_if_next: loop_labels.loop_start,
});
}
other => unreachable!("Unsupported table reference type: {:?}", other),
}
}
Operation::Search(search) => {

View File

@@ -204,16 +204,16 @@ fn eliminate_constant_conditions(
}
fn push_scan_direction(table: &mut TableReference, direction: &Direction) {
match &mut table.op {
Operation::Scan { iter_dir, .. } => {
if iter_dir.is_none() {
match direction {
Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards),
Direction::Descending => *iter_dir = Some(IterationDirection::Backwards),
}
if let Operation::Scan {
ref mut iter_dir, ..
} = table.op
{
if iter_dir.is_none() {
match direction {
Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards),
Direction::Descending => *iter_dir = Some(IterationDirection::Backwards),
}
}
_ => {}
}
}
@@ -307,14 +307,14 @@ impl Optimizable for ast::Expr {
else {
return Ok(None);
};
let column = table_reference.table.get_column_at(*column);
let Some(column) = table_reference.table.get_column_at(*column) else {
return Ok(None);
};
for index in available_indexes_for_table.iter() {
if column
.name
.as_ref()
.map_or(false, |name| *name == index.columns.first().unwrap().name)
{
return Ok(Some(index.clone()));
if let Some(name) = column.name.as_ref() {
if &index.columns.first().unwrap().name == name {
return Ok(Some(index.clone()));
}
}
}
Ok(None)

View File

@@ -9,6 +9,7 @@ use crate::{
function::AggFunc,
schema::{BTreeTable, Column, Index, Table},
vdbe::BranchOffset,
VirtualTable,
};
use crate::{
schema::{PseudoTable, Type},
@@ -199,9 +200,6 @@ pub struct TableReference {
pub join_info: Option<JoinInfo>,
}
/**
A SourceOperator is a reference in the query plan that reads data from a table.
*/
#[derive(Clone, Debug)]
pub enum Operation {
// Scan operation
@@ -229,7 +227,16 @@ pub enum Operation {
impl TableReference {
/// Returns the btree table for this table reference, if it is a BTreeTable.
pub fn btree(&self) -> Option<Rc<BTreeTable>> {
self.table.btree()
match &self.table {
Table::BTree(_) => self.table.btree(),
_ => None,
}
}
pub fn virtual_table(&self) -> Option<Rc<VirtualTable>> {
match &self.table {
Table::Virtual(_) => self.table.virtual_table(),
_ => None,
}
}
/// Creates a new TableReference for a subquery.

View File

@@ -11,7 +11,7 @@ use crate::{
schema::{Schema, Table},
util::{exprs_are_equivalent, normalize_ident},
vdbe::BranchOffset,
Result,
Result, VirtualTable,
};
use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit, UnaryOperator};
@@ -317,8 +317,36 @@ fn parse_from_clause_table(
ast::As::Elided(id) => id.0.clone(),
})
.unwrap_or(format!("subquery_{}", cur_table_index));
let table_reference = TableReference::new_subquery(identifier, subplan, None);
Ok(table_reference)
Ok(TableReference::new_subquery(identifier, subplan, None))
}
ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => {
let normalized_name = &normalize_ident(qualified_name.name.0.as_str());
let Some(vtab) = syms.vtabs.get(normalized_name) else {
crate::bail_parse_error!("Virtual table {} not found", normalized_name);
};
let alias = maybe_alias
.as_ref()
.map(|a| match a {
ast::As::As(id) => id.0.clone(),
ast::As::Elided(id) => id.0.clone(),
})
.unwrap_or(normalized_name.to_string());
Ok(TableReference {
op: Operation::Scan { iter_dir: None },
join_info: None,
table: Table::Virtual(
VirtualTable {
name: normalized_name.clone(),
args: maybe_args,
implementation: vtab.implementation.clone(),
columns: vtab.columns.clone(),
}
.into(),
)
.into(),
identifier: alias.clone(),
})
}
_ => todo!(),
}

View File

@@ -28,9 +28,9 @@ pub fn translate_select(
let mut program = ProgramBuilder::new(ProgramBuilderOpts {
query_mode,
num_cursors: count_plan_required_cursors(&select),
approx_num_insns: estimate_num_instructions(&select),
approx_num_labels: estimate_num_labels(&select),
num_cursors: count_plan_required_cursors(select),
approx_num_insns: estimate_num_instructions(select),
approx_num_labels: estimate_num_labels(select),
});
emit_program(&mut program, select_plan, syms)?;
Ok(program)

View File

@@ -6,6 +6,7 @@ use crate::pseudo::PseudoCursor;
use crate::storage::btree::BTreeCursor;
use crate::storage::sqlite3_ondisk::write_varint;
use crate::vdbe::sorter::Sorter;
use crate::vdbe::VTabOpaqueCursor;
use crate::Result;
use std::fmt::Display;
use std::rc::Rc;
@@ -670,6 +671,7 @@ pub enum Cursor {
Index(BTreeCursor),
Pseudo(PseudoCursor),
Sorter(Sorter),
Virtual(VTabOpaqueCursor),
}
impl Cursor {
@@ -716,6 +718,13 @@ impl Cursor {
_ => panic!("Cursor is not a sorter cursor"),
}
}
pub fn as_virtual_mut(&mut self) -> &mut VTabOpaqueCursor {
match self {
Self::Virtual(cursor) => cursor,
_ => panic!("Cursor is not a virtual cursor"),
}
}
}
pub enum CursorResult<T> {

View File

@@ -1,9 +1,8 @@
use sqlite3_parser::ast::{self, CreateTableBody, Expr, FunctionTail, Literal};
use std::{rc::Rc, sync::Arc};
use sqlite3_parser::ast::{Expr, FunctionTail, Literal};
use crate::{
schema::{self, Schema},
schema::{self, Column, Schema, Type},
Result, Statement, StepResult, IO,
};
@@ -308,6 +307,77 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {
}
}
pub fn columns_from_create_table_body(body: ast::CreateTableBody) -> Result<Vec<Column>, ()> {
let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else {
return Err(());
};
Ok(columns
.into_iter()
.filter_map(|(name, column_def)| {
// if column_def.col_type includes HIDDEN, omit it for now
if let Some(data_type) = column_def.col_type.as_ref() {
if data_type.name.as_str().contains("HIDDEN") {
return None;
}
}
let column = Column {
name: Some(name.0),
ty: match column_def.col_type {
Some(ref data_type) => {
// https://www.sqlite.org/datatype3.html
let type_name = data_type.name.as_str().to_uppercase();
if type_name.contains("INT") {
Type::Integer
} else if type_name.contains("CHAR")
|| type_name.contains("CLOB")
|| type_name.contains("TEXT")
{
Type::Text
} else if type_name.contains("BLOB") || type_name.is_empty() {
Type::Blob
} else if type_name.contains("REAL")
|| type_name.contains("FLOA")
|| type_name.contains("DOUB")
{
Type::Real
} else {
Type::Numeric
}
}
None => Type::Null,
},
default: column_def
.constraints
.iter()
.find_map(|c| match &c.constraint {
sqlite3_parser::ast::ColumnConstraint::Default(val) => Some(val.clone()),
_ => None,
}),
notnull: column_def.constraints.iter().any(|c| {
matches!(
c.constraint,
sqlite3_parser::ast::ColumnConstraint::NotNull { .. }
)
}),
ty_str: column_def
.col_type
.clone()
.map(|t| t.name.to_string())
.unwrap_or_default(),
primary_key: column_def.constraints.iter().any(|c| {
matches!(
c.constraint,
sqlite3_parser::ast::ColumnConstraint::PrimaryKey { .. }
)
}),
is_rowid_alias: false,
};
Some(column)
})
.collect::<Vec<_>>())
}
#[cfg(test)]
pub mod tests {
use super::*;

View File

@@ -9,7 +9,7 @@ use crate::{
schema::{BTreeTable, Index, PseudoTable},
storage::sqlite3_ondisk::DatabaseHeader,
translate::plan::{ResultSetColumn, TableReference},
Connection,
Connection, VirtualTable,
};
use super::{BranchOffset, CursorID, Insn, InsnReference, Program};
@@ -40,6 +40,7 @@ pub enum CursorType {
BTreeIndex(Rc<Index>),
Pseudo(Rc<PseudoTable>),
Sorter,
VirtualTable(Rc<VirtualTable>),
}
impl CursorType {
@@ -406,6 +407,9 @@ impl ProgramBuilder {
Insn::IsNull { reg: _, target_pc } => {
resolve(target_pc, "IsNull");
}
Insn::VNext { pc_if_next, .. } => {
resolve(pc_if_next, "VNext");
}
_ => continue,
}
}

View File

@@ -363,6 +363,62 @@ pub fn insn_to_str(
0,
"".to_string(),
),
Insn::VOpenAsync { cursor_id } => (
"VOpenAsync",
*cursor_id as i32,
0,
0,
OwnedValue::build_text(Rc::new("".to_string())),
0,
"".to_string(),
),
Insn::VOpenAwait => (
"VOpenAwait",
0,
0,
0,
OwnedValue::build_text(Rc::new("".to_string())),
0,
"".to_string(),
),
Insn::VFilter {
cursor_id,
arg_count,
args_reg,
} => (
"VFilter",
*cursor_id as i32,
*arg_count as i32,
*args_reg as i32,
OwnedValue::build_text(Rc::new("".to_string())),
0,
"".to_string(),
),
Insn::VColumn {
cursor_id,
column,
dest,
} => (
"VColumn",
*cursor_id as i32,
*column as i32,
*dest as i32,
OwnedValue::build_text(Rc::new("".to_string())),
0,
"".to_string(),
),
Insn::VNext {
cursor_id,
pc_if_next,
} => (
"VNext",
*cursor_id as i32,
pc_if_next.to_debug_int(),
0,
OwnedValue::build_text(Rc::new("".to_string())),
0,
"".to_string(),
),
Insn::OpenPseudo {
cursor_id,
content_reg,
@@ -423,6 +479,7 @@ pub fn insn_to_str(
name
}
CursorType::Sorter => None,
CursorType::VirtualTable(v) => v.columns.get(*column).unwrap().name.as_ref(),
};
(
"Column",

View File

@@ -213,6 +213,35 @@ pub enum Insn {
// Await for the completion of open cursor.
OpenReadAwait,
/// Open a cursor for a virtual table.
VOpenAsync {
cursor_id: CursorID,
},
/// Await for the completion of open cursor for a virtual table.
VOpenAwait,
/// Initialize the position of the virtual table cursor.
VFilter {
cursor_id: CursorID,
arg_count: usize,
args_reg: usize,
},
/// Read a column from the current row of the virtual table cursor.
VColumn {
cursor_id: CursorID,
column: usize,
dest: usize,
},
/// Advance the virtual table cursor to the next row.
/// TODO: async
VNext {
cursor_id: CursorID,
pc_if_next: BranchOffset,
},
// Open a cursor for a pseudo-table that contains a single row.
OpenPseudo {
cursor_id: CursorID,

View File

@@ -65,6 +65,7 @@ use sorter::Sorter;
use std::borrow::BorrowMut;
use std::cell::{Cell, RefCell, RefMut};
use std::collections::HashMap;
use std::ffi::c_void;
use std::num::NonZero;
use std::rc::{Rc, Weak};
@@ -267,6 +268,19 @@ fn get_cursor_as_sorter_mut<'long, 'short>(
cursor
}
fn get_cursor_as_virtual_mut<'long, 'short>(
cursors: &'short mut RefMut<'long, Vec<Option<Cursor>>>,
cursor_id: CursorID,
) -> &'short mut VTabOpaqueCursor {
let cursor = cursors
.get_mut(cursor_id)
.expect("cursor id out of bounds")
.as_mut()
.expect("cursor not allocated")
.as_virtual_mut();
cursor
}
struct Bitfield<const N: usize>([u64; N]);
impl<const N: usize> Bitfield<N> {
@@ -290,6 +304,18 @@ impl<const N: usize> Bitfield<N> {
}
}
pub struct VTabOpaqueCursor(*mut c_void);
impl VTabOpaqueCursor {
pub fn new(cursor: *mut c_void) -> Self {
Self(cursor)
}
pub fn as_ptr(&self) -> *mut c_void {
self.0
}
}
/// The program state describes the environment in which the program executes.
pub struct ProgramState {
pub pc: InsnReference,
@@ -370,6 +396,7 @@ macro_rules! must_be_btree_cursor {
CursorType::BTreeIndex(_) => get_cursor_as_index_mut(&mut $cursors, $cursor_id),
CursorType::Pseudo(_) => panic!("{} on pseudo cursor", $insn_name),
CursorType::Sorter => panic!("{} on sorter cursor", $insn_name),
CursorType::VirtualTable(_) => panic!("{} on virtual table cursor", $insn_name),
};
cursor
}};
@@ -826,12 +853,79 @@ impl Program {
CursorType::Sorter => {
panic!("OpenReadAsync on sorter cursor");
}
CursorType::VirtualTable(_) => {
panic!("OpenReadAsync on virtual table cursor, use Insn::VOpenAsync instead");
}
}
state.pc += 1;
}
Insn::OpenReadAwait => {
state.pc += 1;
}
Insn::VOpenAsync { cursor_id } => {
let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap();
let CursorType::VirtualTable(virtual_table) = cursor_type else {
panic!("VOpenAsync on non-virtual table cursor");
};
let cursor = virtual_table.open();
state
.cursors
.borrow_mut()
.insert(*cursor_id, Some(Cursor::Virtual(cursor)));
state.pc += 1;
}
Insn::VOpenAwait => {
state.pc += 1;
}
Insn::VFilter {
cursor_id,
arg_count,
args_reg,
} => {
let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap();
let CursorType::VirtualTable(virtual_table) = cursor_type else {
panic!("VFilter on non-virtual table cursor");
};
let mut cursors = state.cursors.borrow_mut();
let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id);
let mut args = Vec::new();
for i in 0..*arg_count {
args.push(state.registers[args_reg + i].clone());
}
virtual_table.filter(cursor, *arg_count, args)?;
state.pc += 1;
}
Insn::VColumn {
cursor_id,
column,
dest,
} => {
let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap();
let CursorType::VirtualTable(virtual_table) = cursor_type else {
panic!("VColumn on non-virtual table cursor");
};
let mut cursors = state.cursors.borrow_mut();
let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id);
state.registers[*dest] = virtual_table.column(cursor, *column)?;
state.pc += 1;
}
Insn::VNext {
cursor_id,
pc_if_next,
} => {
let (_, cursor_type) = self.cursor_ref.get(*cursor_id).unwrap();
let CursorType::VirtualTable(virtual_table) = cursor_type else {
panic!("VNextAsync on non-virtual table cursor");
};
let mut cursors = state.cursors.borrow_mut();
let cursor = get_cursor_as_virtual_mut(&mut cursors, *cursor_id);
let has_more = virtual_table.next(cursor)?;
if has_more {
state.pc = pc_if_next.to_offset_int();
} else {
state.pc += 1;
}
}
Insn::OpenPseudo {
cursor_id,
content_reg: _,
@@ -943,6 +1037,11 @@ impl Program {
state.registers[*dest] = OwnedValue::Null;
}
}
CursorType::VirtualTable(_) => {
panic!(
"Insn::Column on virtual table cursor, use Insn::VColumn instead"
);
}
}
state.pc += 1;

View File

@@ -9,7 +9,8 @@ like traditional `sqlite3` extensions, but are able to be written in much more e
- [ x ] **Scalar Functions**: Create scalar functions using the `scalar` macro.
- [ x ] **Aggregate Functions**: Define aggregate functions with `AggregateDerive` macro and `AggFunc` trait.
- [] **Virtual tables**: TODO
- [ x ] **Virtual tables**: Create a module for a virtual table with the `VTabModuleDerive` macro and `VTabCursor` trait.
- [] **VFS Modules**
---
## Installation
@@ -17,24 +18,32 @@ like traditional `sqlite3` extensions, but are able to be written in much more e
Add the crate to your `Cargo.toml`:
```toml
[features]
static = ["limbo_ext/static"]
[dependencies]
limbo_ext = { path = "path/to/limbo/extensions/core" } # temporary until crate is published
limbo_ext = { path = "path/to/limbo/extensions/core", features = ["static"] } # temporary until crate is published
# mimalloc is required if you intend on linking dynamically. It is imported for you by the register_extension
# macro, so no configuration is needed. But it must be added to your Cargo.toml
[target.'cfg(not(target_family = "wasm"))'.dependencies]
mimalloc = { version = "*", default-features = false }
```
**NOTE** Crate must be of type `cdylib` if you wish to link dynamically
```
# NOTE: Crate must be of type `cdylib` if you wish to link dynamically
[lib]
crate-type = ["cdylib", "lib"]
```
`cargo build` will output a shared library that can be loaded with `.load target/debug/libyour_crate_name`
`cargo build` will output a shared library that can be loaded by the following options:
#### **CLI:**
`.load target/debug/libyour_crate_name`
#### **SQL:**
`SELECT load_extension('target/debug/libyour_crate_name')`
Extensions can be registered with the `register_extension!` macro:
@@ -44,6 +53,7 @@ Extensions can be registered with the `register_extension!` macro:
register_extension!{
scalars: { double }, // name of your function, if different from attribute name
aggregates: { Percentile },
vtabs: { CsvVTable },
}
```
@@ -140,4 +150,101 @@ impl AggFunc for Percentile {
}
```
### Virtual Table Example:
```rust
/// Example: A virtual table that operates on a CSV file as a database table.
/// This example assumes that the CSV file is located at "data.csv" in the current directory.
#[derive(Debug, VTabModuleDerive)]
struct CsvVTable;
impl VTabModule for CsvVTable {
type VCursor = CsvCursor;
/// Declare the name for your virtual table
const NAME: &'static str = "csv_data";
/// Declare the table schema and call `api.declare_virtual_table` with the schema sql.
fn connect(api: &ExtensionApi) -> ResultCode {
let sql = "CREATE TABLE csv_data(
name TEXT,
age TEXT,
city TEXT
)";
api.declare_virtual_table(Self::NAME, sql)
}
/// Open to return a new cursor: In this simple example, the CSV file is read completely into memory on connect.
fn open() -> Self::VCursor {
// Read CSV file contents from "data.csv"
let csv_content = fs::read_to_string("data.csv").unwrap_or_default();
// For simplicity, we'll ignore the header row.
let rows: Vec<Vec<String>> = csv_content
.lines()
.skip(1)
.map(|line| {
line.split(',')
.map(|s| s.trim().to_string())
.collect()
})
.collect();
CsvCursor { rows, index: 0 }
}
/// Filter through result columns. (not used in this simple example)
fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode {
ResultCode::OK
}
/// Return the value for the column at the given index in the current row.
fn column(cursor: &Self::VCursor, idx: u32) -> Value {
cursor.column(idx)
}
/// Next advances the cursor to the next row.
fn next(cursor: &mut Self::VCursor) -> ResultCode {
if cursor.index < cursor.rows.len() - 1 {
cursor.index += 1;
ResultCode::OK
} else {
ResultCode::EOF
}
}
/// Return true if the cursor is at the end.
fn eof(cursor: &Self::VCursor) -> bool {
cursor.index >= cursor.rows.len()
}
}
/// The cursor for iterating over CSV rows.
#[derive(Debug)]
struct CsvCursor {
rows: Vec<Vec<String>>,
index: usize,
}
/// Implement the VTabCursor trait for your cursor type
impl VTabCursor for CsvCursor {
fn next(&mut self) -> ResultCode {
CsvCursor::next(self)
}
fn eof(&self) -> bool {
self.index >= self.rows.len()
}
fn column(&self, idx: u32) -> Value {
let row = &self.rows[self.index];
if (idx as usize) < row.len() {
Value::from_text(&row[idx as usize])
} else {
Value::null()
}
}
fn rowid(&self) -> i64 {
self.index as i64
}
}
```

View File

@@ -1,5 +1,5 @@
mod types;
pub use limbo_macros::{register_extension, scalar, AggregateDerive};
pub use limbo_macros::{register_extension, scalar, AggregateDerive, VTabModuleDerive};
use std::os::raw::{c_char, c_void};
pub use types::{ResultCode, Value, ValueType};
@@ -21,6 +21,30 @@ pub struct ExtensionApi {
step_func: StepFunction,
finalize_func: FinalizeFunction,
) -> ResultCode,
pub register_module: unsafe extern "C" fn(
ctx: *mut c_void,
name: *const c_char,
module: VTabModuleImpl,
) -> ResultCode,
pub declare_vtab: unsafe extern "C" fn(
ctx: *mut c_void,
name: *const c_char,
sql: *const c_char,
) -> ResultCode,
}
impl ExtensionApi {
pub fn declare_virtual_table(&self, name: &str, sql: &str) -> ResultCode {
let Ok(name) = std::ffi::CString::new(name) else {
return ResultCode::Error;
};
let Ok(sql) = std::ffi::CString::new(sql) else {
return ResultCode::Error;
};
unsafe { (self.declare_vtab)(self.ctx, name.as_ptr(), sql.as_ptr()) }
}
}
pub type ExtensionEntryPoint = unsafe extern "C" fn(api: *const ExtensionApi) -> ResultCode;
@@ -47,3 +71,53 @@ pub trait AggFunc {
fn step(state: &mut Self::State, args: &[Value]);
fn finalize(state: Self::State) -> Value;
}
#[repr(C)]
#[derive(Clone, Debug)]
pub struct VTabModuleImpl {
pub name: *const c_char,
pub connect: VtabFnConnect,
pub open: VtabFnOpen,
pub filter: VtabFnFilter,
pub column: VtabFnColumn,
pub next: VtabFnNext,
pub eof: VtabFnEof,
}
pub type VtabFnConnect = unsafe extern "C" fn(api: *const c_void) -> ResultCode;
pub type VtabFnOpen = unsafe extern "C" fn() -> *mut c_void;
pub type VtabFnFilter =
unsafe extern "C" fn(cursor: *mut c_void, argc: i32, argv: *const Value) -> ResultCode;
pub type VtabFnColumn = unsafe extern "C" fn(cursor: *mut c_void, idx: u32) -> Value;
pub type VtabFnNext = unsafe extern "C" fn(cursor: *mut c_void) -> ResultCode;
pub type VtabFnEof = unsafe extern "C" fn(cursor: *mut c_void) -> bool;
pub trait VTabModule: 'static {
type VCursor: VTabCursor;
const NAME: &'static str;
fn connect(api: &ExtensionApi) -> ResultCode;
fn open() -> Self::VCursor;
fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode;
fn column(cursor: &Self::VCursor, idx: u32) -> Value;
fn next(cursor: &mut Self::VCursor) -> ResultCode;
fn eof(cursor: &Self::VCursor) -> bool;
}
pub trait VTabCursor: Sized {
type Error: std::fmt::Display;
fn rowid(&self) -> i64;
fn column(&self, idx: u32) -> Value;
fn eof(&self) -> bool;
fn next(&mut self) -> ResultCode;
}
#[repr(C)]
pub struct VTabImpl {
pub module: VTabModuleImpl,
}

View File

@@ -2,8 +2,8 @@ use std::fmt::Display;
/// Error type is of type ExtError which can be
/// either a user defined error or an error code
#[derive(Clone, Copy)]
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ResultCode {
OK = 0,
Error = 1,
@@ -20,6 +20,7 @@ pub enum ResultCode {
Internal = 12,
Unavailable = 13,
CustomError = 14,
EOF = 15,
}
impl ResultCode {
@@ -50,6 +51,7 @@ impl Display for ResultCode {
ResultCode::Internal => write!(f, "Internal Error"),
ResultCode::Unavailable => write!(f, "Unavailable"),
ResultCode::CustomError => write!(f, "Error "),
ResultCode::EOF => write!(f, "EOF"),
}
}
}

View File

@@ -0,0 +1,21 @@
[package]
name = "limbo_series"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
[features]
static = ["limbo_ext/static"]
[lib]
crate-type = ["cdylib", "lib"]
[dependencies]
limbo_ext = { path = "../core", features = ["static"] }
log = "0.4.20"
[target.'cfg(not(target_family = "wasm"))'.dependencies]
mimalloc = { version = "*", default-features = false }

View File

@@ -0,0 +1,118 @@
use limbo_ext::{
register_extension, ExtensionApi, ResultCode, VTabCursor, VTabModule, VTabModuleDerive, Value,
};
register_extension! {
vtabs: { GenerateSeriesVTab }
}
macro_rules! try_option {
($expr:expr, $err:expr) => {
match $expr {
Some(val) => val,
None => return $err,
}
};
}
/// A virtual table that generates a sequence of integers
#[derive(Debug, VTabModuleDerive)]
struct GenerateSeriesVTab;
impl VTabModule for GenerateSeriesVTab {
type VCursor = GenerateSeriesCursor;
const NAME: &'static str = "generate_series";
fn connect(api: &ExtensionApi) -> ResultCode {
// Create table schema
let sql = "CREATE TABLE generate_series(
value INTEGER,
start INTEGER HIDDEN,
stop INTEGER HIDDEN,
step INTEGER HIDDEN
)";
api.declare_virtual_table(Self::NAME, sql)
}
fn open() -> Self::VCursor {
GenerateSeriesCursor {
start: 0,
stop: 0,
step: 0,
current: 0,
}
}
fn filter(cursor: &mut Self::VCursor, arg_count: i32, args: &[Value]) -> ResultCode {
// args are the start, stop, and step
if arg_count == 0 || arg_count > 3 {
return ResultCode::InvalidArgs;
}
let start = try_option!(args[0].to_integer(), ResultCode::InvalidArgs);
let stop = try_option!(
args.get(1).map(|v| v.to_integer().unwrap_or(i64::MAX)),
ResultCode::InvalidArgs
);
let step = try_option!(
args.get(2).map(|v| v.to_integer().unwrap_or(1)),
ResultCode::InvalidArgs
);
cursor.start = start;
cursor.current = start;
cursor.step = step;
cursor.stop = stop;
ResultCode::OK
}
fn column(cursor: &Self::VCursor, idx: u32) -> Value {
cursor.column(idx)
}
fn next(cursor: &mut Self::VCursor) -> ResultCode {
GenerateSeriesCursor::next(cursor)
}
fn eof(cursor: &Self::VCursor) -> bool {
cursor.eof()
}
}
/// The cursor for iterating over the generated sequence
#[derive(Debug)]
struct GenerateSeriesCursor {
start: i64,
stop: i64,
step: i64,
current: i64,
}
impl VTabCursor for GenerateSeriesCursor {
type Error = ResultCode;
fn next(&mut self) -> ResultCode {
let next_val = self.current.saturating_add(self.step);
if (self.step > 0 && next_val > self.stop) || (self.step < 0 && next_val < self.stop) {
return ResultCode::EOF;
}
self.current = next_val;
ResultCode::OK
}
fn eof(&self) -> bool {
(self.step > 0 && self.current > self.stop) || (self.step < 0 && self.current < self.stop)
}
fn column(&self, idx: u32) -> Value {
match idx {
0 => Value::from_integer(self.current),
1 => Value::from_integer(self.start),
2 => Value::from_integer(self.stop),
3 => Value::from_integer(self.step),
_ => Value::null(),
}
}
fn rowid(&self) -> i64 {
((self.current - self.start) / self.step) + 1
}
}

View File

@@ -6,31 +6,32 @@ use syn::{Ident, LitStr, Token};
pub(crate) struct RegisterExtensionInput {
pub aggregates: Vec<Ident>,
pub scalars: Vec<Ident>,
pub vtabs: Vec<Ident>,
}
impl syn::parse::Parse for RegisterExtensionInput {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut aggregates = Vec::new();
let mut scalars = Vec::new();
let mut vtabs = Vec::new();
while !input.is_empty() {
if input.peek(syn::Ident) && input.peek2(Token![:]) {
let section_name: Ident = input.parse()?;
input.parse::<Token![:]>()?;
if section_name == "aggregates" || section_name == "scalars" {
let names = ["aggregates", "scalars", "vtabs"];
if names.contains(&section_name.to_string().as_str()) {
let content;
syn::braced!(content in input);
let parsed_items = Punctuated::<Ident, Token![,]>::parse_terminated(&content)?
.into_iter()
.collect();
if section_name == "aggregates" {
aggregates = parsed_items;
} else {
scalars = parsed_items;
}
match section_name.to_string().as_str() {
"aggregates" => aggregates = parsed_items,
"scalars" => scalars = parsed_items,
"vtabs" => vtabs = parsed_items,
_ => unreachable!(),
};
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
@@ -39,13 +40,14 @@ impl syn::parse::Parse for RegisterExtensionInput {
return Err(syn::Error::new(section_name.span(), "Unknown section"));
}
} else {
return Err(input.error("Expected aggregates: or scalars: section"));
return Err(input.error("Expected aggregates:, scalars:, or vtabs: section"));
}
}
Ok(Self {
aggregates,
scalars,
vtabs,
})
}
}

View File

@@ -324,6 +324,201 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream {
TokenStream::from(expanded)
}
/// Macro to derive a VTabModule for your extension. This macro will generate
/// the necessary functions to register your module with core. You must implement
/// the VTabModule trait for your struct, and the VTabCursor trait for your cursor.
/// ```ignore
///#[derive(Debug, VTabModuleDerive)]
///struct CsvVTab;
///impl VTabModule for CsvVTab {
/// type VCursor = CsvCursor;
/// const NAME: &'static str = "csv_data";
///
/// /// Declare the schema for your virtual table
/// fn connect(api: &ExtensionApi) -> ResultCode {
/// let sql = "CREATE TABLE csv_data(
/// name TEXT,
/// age TEXT,
/// city TEXT
/// )";
/// api.declare_virtual_table(Self::NAME, sql)
/// }
/// /// Open the virtual table and return a cursor
/// fn open() -> Self::VCursor {
/// let csv_content = fs::read_to_string("data.csv").unwrap_or_default();
/// let rows: Vec<Vec<String>> = csv_content
/// .lines()
/// .skip(1)
/// .map(|line| {
/// line.split(',')
/// .map(|s| s.trim().to_string())
/// .collect()
/// })
/// .collect();
/// CsvCursor { rows, index: 0 }
/// }
/// /// Filter the virtual table based on arguments (omitted here for simplicity)
/// fn filter(_cursor: &mut Self::VCursor, _arg_count: i32, _args: &[Value]) -> ResultCode {
/// ResultCode::OK
/// }
/// /// Return the value for a given column index
/// fn column(cursor: &Self::VCursor, idx: u32) -> Value {
/// cursor.column(idx)
/// }
/// /// Move the cursor to the next row
/// fn next(cursor: &mut Self::VCursor) -> ResultCode {
/// if cursor.index < cursor.rows.len() - 1 {
/// cursor.index += 1;
/// ResultCode::OK
/// } else {
/// ResultCode::EOF
/// }
/// }
/// fn eof(cursor: &Self::VCursor) -> bool {
/// cursor.index >= cursor.rows.len()
/// }
/// #[derive(Debug)]
/// struct CsvCursor {
/// rows: Vec<Vec<String>>,
/// index: usize,
///
/// impl CsvCursor {
/// /// Returns the value for a given column index.
/// fn column(&self, idx: u32) -> Value {
/// let row = &self.rows[self.index];
/// if (idx as usize) < row.len() {
/// Value::from_text(&row[idx as usize])
/// } else {
/// Value::null()
/// }
/// }
/// // Implement the VTabCursor trait for your virtual cursor
/// impl VTabCursor for CsvCursor {
/// fn next(&mut self) -> ResultCode {
/// Self::next(self)
/// }
/// fn eof(&self) -> bool {
/// self.index >= self.rows.len()
/// }
/// fn column(&self, idx: u32) -> Value {
/// self.column(idx)
/// }
/// fn rowid(&self) -> i64 {
/// self.index as i64
/// }
///
#[proc_macro_derive(VTabModuleDerive)]
pub fn derive_vtab_module(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let struct_name = &ast.ident;
let register_fn_name = format_ident!("register_{}", struct_name);
let connect_fn_name = format_ident!("connect_{}", struct_name);
let open_fn_name = format_ident!("open_{}", struct_name);
let filter_fn_name = format_ident!("filter_{}", struct_name);
let column_fn_name = format_ident!("column_{}", struct_name);
let next_fn_name = format_ident!("next_{}", struct_name);
let eof_fn_name = format_ident!("eof_{}", struct_name);
let expanded = quote! {
impl #struct_name {
#[no_mangle]
unsafe extern "C" fn #connect_fn_name(
db: *const ::std::ffi::c_void,
) -> ::limbo_ext::ResultCode {
if db.is_null() {
return ::limbo_ext::ResultCode::Error;
}
let api = unsafe { &*(db as *const ExtensionApi) };
<#struct_name as ::limbo_ext::VTabModule>::connect(api)
}
#[no_mangle]
unsafe extern "C" fn #open_fn_name(
) -> *mut ::std::ffi::c_void {
let cursor = <#struct_name as ::limbo_ext::VTabModule>::open();
Box::into_raw(Box::new(cursor)) as *mut ::std::ffi::c_void
}
#[no_mangle]
unsafe extern "C" fn #filter_fn_name(
cursor: *mut ::std::ffi::c_void,
argc: i32,
argv: *const ::limbo_ext::Value,
) -> ::limbo_ext::ResultCode {
if cursor.is_null() {
return ::limbo_ext::ResultCode::Error;
}
let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) };
let args = std::slice::from_raw_parts(argv, argc as usize);
<#struct_name as ::limbo_ext::VTabModule>::filter(cursor, argc, args)
}
#[no_mangle]
unsafe extern "C" fn #column_fn_name(
cursor: *mut ::std::ffi::c_void,
idx: u32,
) -> ::limbo_ext::Value {
if cursor.is_null() {
return ::limbo_ext::Value::error(ResultCode::Error);
}
let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) };
<#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx)
}
#[no_mangle]
unsafe extern "C" fn #next_fn_name(
cursor: *mut ::std::ffi::c_void,
) -> ::limbo_ext::ResultCode {
if cursor.is_null() {
return ::limbo_ext::ResultCode::Error;
}
let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) };
<#struct_name as ::limbo_ext::VTabModule>::next(cursor)
}
#[no_mangle]
unsafe extern "C" fn #eof_fn_name(
cursor: *mut ::std::ffi::c_void,
) -> bool {
if cursor.is_null() {
return true;
}
let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) };
<#struct_name as ::limbo_ext::VTabModule>::eof(cursor)
}
#[no_mangle]
pub unsafe extern "C" fn #register_fn_name(
api: *const ::limbo_ext::ExtensionApi
) -> ::limbo_ext::ResultCode {
if api.is_null() {
return ::limbo_ext::ResultCode::Error;
}
let api = &*api;
let name = <#struct_name as ::limbo_ext::VTabModule>::NAME;
// name needs to be a c str FFI compatible, NOT CString
let name_c = std::ffi::CString::new(name).unwrap();
let module = ::limbo_ext::VTabModuleImpl {
name: name_c.as_ptr(),
connect: Self::#connect_fn_name,
open: Self::#open_fn_name,
filter: Self::#filter_fn_name,
column: Self::#column_fn_name,
next: Self::#next_fn_name,
eof: Self::#eof_fn_name,
};
(api.register_module)(api.ctx, name_c.as_ptr(), module)
}
}
};
TokenStream::from(expanded)
}
/// Register your extension with 'core' by providing the relevant functions
///```ignore
///use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc};
@@ -362,6 +557,7 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
let RegisterExtensionInput {
aggregates,
scalars,
vtabs,
} = input_ast;
let scalar_calls = scalars.iter().map(|scalar_ident| {
@@ -388,8 +584,23 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
}
}
});
let vtab_calls = vtabs.iter().map(|vtab_ident| {
let register_fn = syn::Ident::new(&format!("register_{}", vtab_ident), vtab_ident.span());
quote! {
{
let result = unsafe{ #vtab_ident::#register_fn(api)};
if result == ::limbo_ext::ResultCode::OK {
let result = <#vtab_ident as ::limbo_ext::VTabModule>::connect(api);
if !result.is_ok() {
return result;
}
}
}
}
});
let static_aggregates = aggregate_calls.clone();
let static_scalars = scalar_calls.clone();
let static_vtabs = vtab_calls.clone();
let expanded = quote! {
#[cfg(not(target_family = "wasm"))]
@@ -404,19 +615,23 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
#(#static_aggregates)*
#(#static_vtabs)*
::limbo_ext::ResultCode::OK
}
#[cfg(not(feature = "static"))]
#[no_mangle]
pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode {
let api = unsafe { &*api };
#(#scalar_calls)*
#[no_mangle]
pub unsafe extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode {
let api = unsafe { &*api };
#(#scalar_calls)*
#(#aggregate_calls)*
#(#aggregate_calls)*
::limbo_ext::ResultCode::OK
}
#(#vtab_calls)*
::limbo_ext::ResultCode::OK
}
};
TokenStream::from(expanded)

View File

@@ -110,14 +110,19 @@ def validate_blob(result):
# and assert they are valid hex digits
return int(result, 16) is not None
def validate_string_uuid(result):
return len(result) == 36 and result.count("-") == 4
def returns_error(result):
def returns_error_no_func(result):
return "error: no such function: " in result
def returns_vtable_parse_err(result):
return "Parse error: Virtual table" in result
def returns_null(result):
return result == "" or result == "\n"
@@ -129,6 +134,7 @@ def assert_now_unixtime(result):
def assert_specific_time(result):
return result == "1736720789"
def test_uuid(pipe):
specific_time = "01945ca0-3189-76c0-9a8f-caf310fc8b8e"
# these are built into the binary, so we just test they work
@@ -165,7 +171,7 @@ def test_regexp(pipe):
extension_path = "./target/debug/liblimbo_regexp.so"
# before extension loads, assert no function
run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error)
run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error_no_func)
run_test(pipe, f".load {extension_path}", returns_null)
print(f"Extension {extension_path} loaded successfully.")
run_test(pipe, "SELECT regexp('a.c', 'abc');", validate_true)
@@ -205,13 +211,14 @@ def validate_percentile2(res):
def validate_percentile_disc(res):
return res == "40.0"
def test_aggregates(pipe):
extension_path = "./target/debug/liblimbo_percentile.so"
# assert no function before extension loads
run_test(
pipe,
"SELECT median(1);",
returns_error,
returns_error_no_func,
"median agg function returns null when ext not loaded",
)
run_test(
@@ -252,63 +259,55 @@ def test_aggregates(pipe):
pipe, "SELECT percentile_disc(value, 0.55) from test;", validate_percentile_disc
)
# Hashes
def validate_blake3(a):
return a == "6437b3ac38465133ffb63b75273a8db548c558465d79db03fd359c6cd5bd9d85"
def validate_md5(a):
return a == "900150983cd24fb0d6963f7d28e17f72"
def validate_sha1(a):
return a == "a9993e364706816aba3e25717850c26c9cd0d89d"
def validate_sha256(a):
return a == "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
def validate_sha384(a):
return a == "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7"
def validate_sha512(a):
return a == "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"
# Encoders and decoders
def validate_url_encode(a):
return a == f"%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29"
return a == "%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29"
def validate_url_decode(a):
return a == "/hello?text=(ಠ_ಠ)"
def validate_hex_encode(a):
return a == "68656c6c6f"
def validate_hex_decode(a):
return a == "hello"
def validate_base85_encode(a):
return a == "BOu!rDZ"
def validate_base85_decode(a):
return a == "hello"
def validate_base32_encode(a):
return a == "NBSWY3DP"
def validate_base32_decode(a):
return a == "hello"
def validate_base64_encode(a):
return a == "aGVsbG8="
def validate_base64_decode(a):
return a == "hello"
def test_crypto(pipe):
extension_path = "./target/debug/liblimbo_crypto.so"
# assert no function before extension loads
run_test(
pipe,
"SELECT crypto_blake('a');",
returns_error,
lambda res: "Parse error" in res,
"crypto_blake3 returns null when ext not loaded",
)
run_test(
@@ -321,104 +320,139 @@ def test_crypto(pipe):
run_test(
pipe,
"SELECT crypto_encode(crypto_blake3('abc'), 'hex');",
validate_blake3,
"blake3 should encrypt correctly"
lambda res: res
== "6437b3ac38465133ffb63b75273a8db548c558465d79db03fd359c6cd5bd9d85",
"blake3 should encrypt correctly",
)
run_test(
pipe,
"SELECT crypto_encode(crypto_md5('abc'), 'hex');",
validate_md5,
"md5 should encrypt correctly"
lambda res: res == "900150983cd24fb0d6963f7d28e17f72",
"md5 should encrypt correctly",
)
run_test(
pipe,
"SELECT crypto_encode(crypto_sha1('abc'), 'hex');",
validate_sha1,
"sha1 should encrypt correctly"
lambda res: res == "a9993e364706816aba3e25717850c26c9cd0d89d",
"sha1 should encrypt correctly",
)
run_test(
pipe,
"SELECT crypto_encode(crypto_sha256('abc'), 'hex');",
validate_sha256,
"sha256 should encrypt correctly"
lambda a: a
== "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad",
"sha256 should encrypt correctly",
)
run_test(
pipe,
"SELECT crypto_encode(crypto_sha384('abc'), 'hex');",
validate_sha384,
"sha384 should encrypt correctly"
lambda a: a
== "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7",
"sha384 should encrypt correctly",
)
run_test(
pipe,
"SELECT crypto_encode(crypto_sha512('abc'), 'hex');",
validate_sha512,
"sha512 should encrypt correctly"
)
lambda a: a
== "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f",
"sha512 should encrypt correctly",
)
# Encoding and Decoding
run_test(
pipe,
"SELECT crypto_encode('hello', 'base32');",
validate_base32_encode,
"base32 should encode correctly"
)
"base32 should encode correctly",
)
run_test(
pipe,
"SELECT crypto_decode('NBSWY3DP', 'base32');",
validate_base32_decode,
"base32 should decode correctly"
"base32 should decode correctly",
)
run_test(
pipe,
"SELECT crypto_encode('hello', 'base64');",
validate_base64_encode,
"base64 should encode correctly"
"base64 should encode correctly",
)
run_test(
pipe,
"SELECT crypto_decode('aGVsbG8=', 'base64');",
validate_base64_decode,
"base64 should decode correctly"
"base64 should decode correctly",
)
run_test(
pipe,
"SELECT crypto_encode('hello', 'base85');",
validate_base85_encode,
"base85 should encode correctly"
"base85 should encode correctly",
)
run_test(
pipe,
"SELECT crypto_decode('BOu!rDZ', 'base85');",
validate_base85_decode,
"base85 should decode correctly"
"base85 should decode correctly",
)
run_test(
pipe,
"SELECT crypto_encode('hello', 'hex');",
validate_hex_encode,
"hex should encode correctly"
"hex should encode correctly",
)
run_test(
pipe,
"SELECT crypto_decode('68656c6c6f', 'hex');",
validate_hex_decode,
"hex should decode correctly"
"hex should decode correctly",
)
run_test(
pipe,
"SELECT crypto_encode('/hello?text=(ಠ_ಠ)', 'url');",
validate_url_encode,
"url should encode correctly"
"url should encode correctly",
)
run_test(
pipe,
f"SELECT crypto_decode('%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29', 'url');",
"SELECT crypto_decode('%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29', 'url');",
validate_url_decode,
"url should decode correctly"
"url should decode correctly",
)
def test_series(pipe):
ext_path = "./target/debug/liblimbo_series"
run_test(
pipe,
"SELECT * FROM generate_series(1, 10);",
lambda res: "Virtual table generate_series not found" in res,
)
run_test(pipe, f".load {ext_path}", returns_null)
run_test(
pipe,
"SELECT * FROM generate_series(1, 10);",
lambda res: "Invalid Argument" in res,
)
run_test(
pipe,
"SELECT * FROM generate_series(1, 10, 2);",
lambda res: res == "1\n3\n5\n7\n9",
)
run_test(
pipe,
"SELECT * FROM generate_series(1, 10, 2, 3);",
lambda res: "Invalid Argument" in res,
)
run_test(
pipe,
"SELECT * FROM generate_series(10, 1, -2);",
lambda res: res == "10\n8\n6\n4\n2",
)
def main():
pipe = init_limbo()
try:
@@ -426,6 +460,8 @@ def main():
test_uuid(pipe)
test_aggregates(pipe)
test_crypto(pipe)
test_series(pipe)
except Exception as e:
print(f"Test FAILED: {e}")
pipe.terminate()