diff --git a/core/lib.rs b/core/lib.rs index a4366ed7c..6c0b7b6c2 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -77,6 +77,7 @@ use translate::select::prepare_select_plan; pub use types::RefValue; pub use types::Value; use util::{columns_from_create_table_body, parse_schema_rows}; +use vdbe::builder::TableRefIdCounter; use vdbe::{builder::QueryMode, VTabOpaqueCursor}; pub type Result = std::result::Result; pub static DATABASE_VERSION: OnceLock = OnceLock::new(); @@ -407,6 +408,7 @@ impl Connection { Ok(Some(stmt)) } Cmd::ExplainQueryPlan(stmt) => { + let mut table_ref_counter = TableRefIdCounter::new(); match stmt { ast::Stmt::Select(select) => { let mut plan = prepare_select_plan( @@ -417,6 +419,7 @@ impl Connection { *select, &syms, None, + &mut table_ref_counter, )?; optimize_plan( &mut plan, diff --git a/core/translate/delete.rs b/core/translate/delete.rs index c6d7a7869..a576c6cc6 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -3,7 +3,7 @@ use crate::translate::emitter::emit_program; use crate::translate::optimizer::optimize_plan; use crate::translate::plan::{DeletePlan, Operation, Plan}; use crate::translate::planner::{parse_limit, parse_where}; -use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, QueryMode}; +use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, QueryMode, TableRefIdCounter}; use crate::{schema::Schema, Result, SymbolTable}; use limbo_sqlite3_parser::ast::{Expr, Limit, QualifiedName}; @@ -18,7 +18,13 @@ pub fn translate_delete( syms: &SymbolTable, mut program: ProgramBuilder, ) -> Result { - let mut delete_plan = prepare_delete_plan(schema, tbl_name, where_clause, limit)?; + let mut delete_plan = prepare_delete_plan( + schema, + tbl_name, + where_clause, + limit, + &mut program.table_reference_counter, + )?; optimize_plan(&mut delete_plan, schema)?; let Plan::Delete(ref delete) = delete_plan else { panic!("delete_plan is not a DeletePlan"); @@ -39,6 +45,7 @@ pub fn prepare_delete_plan( tbl_name: &QualifiedName, where_clause: Option>, limit: Option>, + table_ref_counter: &mut TableRefIdCounter, ) -> Result { let table = match schema.get_table(tbl_name.name.0.as_str()) { Some(table) => table, @@ -60,6 +67,7 @@ pub fn prepare_delete_plan( let mut table_references = vec![TableReference { table, identifier: name, + internal_id: table_ref_counter.next(), op: Operation::Scan { iter_dir: IterationDirection::Forwards, index: None, diff --git a/core/translate/expr.rs b/core/translate/expr.rs index b80ae0273..0a71a834d 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1790,7 +1790,12 @@ pub fn translate_expr( column, is_rowid_alias, } => { - let table_reference = referenced_tables.as_ref().unwrap().get(*table).unwrap(); + let table_reference = referenced_tables + .as_ref() + .unwrap() + .iter() + .find(|t| t.internal_id == *table) + .unwrap(); let index = table_reference.op.index(); let use_covering_index = table_reference.utilizes_covering_index(); @@ -1883,7 +1888,12 @@ pub fn translate_expr( } } ast::Expr::RowId { database: _, table } => { - let table_reference = referenced_tables.as_ref().unwrap().get(*table).unwrap(); + let table_reference = referenced_tables + .as_ref() + .unwrap() + .iter() + .find(|t| t.internal_id == *table) + .unwrap(); let index = table_reference.op.index(); let use_covering_index = table_reference.utilizes_covering_index(); if use_covering_index { diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 75d111f38..971ebe403 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -135,7 +135,11 @@ pub fn init_group_by( CollationSeq::new(collation_name).map(Some) } ast::Expr::Column { table, column, .. } => { - let table_reference = plan.table_references.get(*table).unwrap(); + let table_reference = plan + .table_references + .iter() + .find(|t| t.internal_id == *table) + .unwrap(); let Some(table_column) = table_reference.table.get_column_at(*column) else { crate::bail_parse_error!("column index out of bounds"); diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 8206cdef5..d66ac22ed 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -298,7 +298,7 @@ pub fn open_loop( predicates: &mut [WhereTerm], ) -> Result<()> { for (join_index, join) in join_order.iter().enumerate() { - let table_index = join.table_no; + let table_index = join.original_idx; let table = &tables[table_index]; let LoopLabels { loop_start, @@ -345,110 +345,112 @@ pub fn open_loop( program.preassign_label_to_next_insn(loop_start); } Table::Virtual(vtab) => { - let (start_reg, count, maybe_idx_str, maybe_idx_int) = - if vtab.kind.eq(&VTabKind::VirtualTable) { - // Virtual‑table (non‑TVF) modules can receive constraints via xBestIndex. - // They return information with which to pass to VFilter operation. - // We forward every predicate that touches vtab columns. - // - // vtab.col = literal (always usable) - // vtab.col = outer_table.col (usable, because outer_table is already positioned) - // vtab.col = later_table.col (forwarded with usable = false) - // - // xBestIndex decides which ones it wants by setting argvIndex and whether the - // core layer may omit them (omit = true). - // We then materialise the RHS/LHS into registers before issuing VFilter. - let converted_constraints = predicates - .iter() - .filter(|p| p.should_eval_at_loop(join_index, join_order)) - .enumerate() - .filter_map(|(i, p)| { - // Build ConstraintInfo from the predicates - convert_where_to_vtab_constraint(p, table_index, i) - .unwrap_or(None) - }) - .collect::>(); - // TODO: get proper order_by information to pass to the vtab. - // maybe encode more info on t_ctx? we need: [col_idx, is_descending] - let index_info = vtab.best_index(&converted_constraints, &[]); + let (start_reg, count, maybe_idx_str, maybe_idx_int) = if vtab + .kind + .eq(&VTabKind::VirtualTable) + { + // Virtual‑table (non‑TVF) modules can receive constraints via xBestIndex. + // They return information with which to pass to VFilter operation. + // We forward every predicate that touches vtab columns. + // + // vtab.col = literal (always usable) + // vtab.col = outer_table.col (usable, because outer_table is already positioned) + // vtab.col = later_table.col (forwarded with usable = false) + // + // xBestIndex decides which ones it wants by setting argvIndex and whether the + // core layer may omit them (omit = true). + // We then materialise the RHS/LHS into registers before issuing VFilter. + let converted_constraints = predicates + .iter() + .filter(|p| p.should_eval_at_loop(join_index, join_order)) + .enumerate() + .filter_map(|(i, p)| { + // Build ConstraintInfo from the predicates + convert_where_to_vtab_constraint(p, table_index, i, join_order) + .unwrap_or(None) + }) + .collect::>(); + // TODO: get proper order_by information to pass to the vtab. + // maybe encode more info on t_ctx? we need: [col_idx, is_descending] + let index_info = vtab.best_index(&converted_constraints, &[]); - // Determine the number of VFilter arguments (constraints with an argv_index). - let args_needed = index_info - .constraint_usages - .iter() - .filter(|u| u.argv_index.is_some()) - .count(); - let start_reg = program.alloc_registers(args_needed); + // Determine the number of VFilter arguments (constraints with an argv_index). + let args_needed = index_info + .constraint_usages + .iter() + .filter(|u| u.argv_index.is_some()) + .count(); + let start_reg = program.alloc_registers(args_needed); - // For each constraint used by best_index, translate the opposite side. - for (i, usage) in index_info.constraint_usages.iter().enumerate() { - if let Some(argv_index) = usage.argv_index { - if let Some(cinfo) = converted_constraints.get(i) { - let (pred_idx, is_rhs) = cinfo.unpack_plan_info(); - if let ast::Expr::Binary(lhs, _, rhs) = - &predicates[pred_idx].expr - { - // translate the opposite side of the referenced vtab column - let expr = if is_rhs { lhs } else { rhs }; - // argv_index is 1-based; adjust to get the proper register offset. - if argv_index == 0 { - // invalid since argv_index is 1-based - continue; - } - let target_reg = - start_reg + (argv_index - 1) as usize; - translate_expr( - program, - Some(tables), - expr, - target_reg, - &t_ctx.resolver, - )?; - if cinfo.usable && usage.omit { - predicates[pred_idx].consumed = true; - } + // For each constraint used by best_index, translate the opposite side. + for (i, usage) in index_info.constraint_usages.iter().enumerate() { + if let Some(argv_index) = usage.argv_index { + if let Some(cinfo) = converted_constraints.get(i) { + let (pred_idx, is_rhs) = cinfo.unpack_plan_info(); + if let ast::Expr::Binary(lhs, _, rhs) = + &predicates[pred_idx].expr + { + // translate the opposite side of the referenced vtab column + let expr = if is_rhs { lhs } else { rhs }; + // argv_index is 1-based; adjust to get the proper register offset. + if argv_index == 0 { + // invalid since argv_index is 1-based + continue; + } + let target_reg = start_reg + (argv_index - 1) as usize; + translate_expr( + program, + Some(tables), + expr, + target_reg, + &t_ctx.resolver, + )?; + if cinfo.usable && usage.omit { + predicates[pred_idx].consumed = true; } } } } - // If best_index provided an idx_str, translate it. - let maybe_idx_str = if let Some(idx_str) = index_info.idx_str { - let reg = program.alloc_register(); - program.emit_insn(Insn::String8 { - dest: reg, - value: idx_str, - }); - Some(reg) - } else { - None - }; - ( - start_reg, - args_needed, - maybe_idx_str, - Some(index_info.idx_num), - ) + } + + // If best_index provided an idx_str, translate it. + let maybe_idx_str = if let Some(idx_str) = index_info.idx_str { + let reg = program.alloc_register(); + program.emit_insn(Insn::String8 { + dest: reg, + value: idx_str, + }); + Some(reg) } else { - // For table-valued functions: translate the table args. - let args = match vtab.args.as_ref() { - Some(args) => args, - None => &vec![], - }; - let start_reg = program.alloc_registers(args.len()); - let mut cur_reg = start_reg; - for arg in args { - let reg = cur_reg; - cur_reg += 1; - let _ = translate_expr( - program, - Some(tables), - arg, - reg, - &t_ctx.resolver, - )?; - } - (start_reg, args.len(), None, None) + None }; + ( + start_reg, + args_needed, + maybe_idx_str, + Some(index_info.idx_num), + ) + } else { + // For table-valued functions: translate the table args. + let args = match vtab.args.as_ref() { + Some(args) => args, + None => &vec![], + }; + let start_reg = program.alloc_registers(args.len()); + let mut cur_reg = start_reg; + for arg in args { + let reg = cur_reg; + cur_reg += 1; + let _ = translate_expr( + program, + Some(tables), + arg, + reg, + &t_ctx.resolver, + )?; + } + (start_reg, args.len(), None, None) + }; // Emit VFilter with the computed arguments. program.emit_insn(Insn::VFilter { @@ -919,7 +921,7 @@ pub fn close_loop( // CLOSE t2 // CLOSE t1 for join in join_order.iter().rev() { - let table_index = join.table_no; + let table_index = join.original_idx; let table = &tables[table_index]; let loop_labels = *t_ctx .labels_main_loop diff --git a/core/translate/optimizer/access_method.rs b/core/translate/optimizer/access_method.rs index 8888e7755..e1b926c8e 100644 --- a/core/translate/optimizer/access_method.rs +++ b/core/translate/optimizer/access_method.rs @@ -54,7 +54,7 @@ pub fn find_best_access_method_for_join_order<'a>( maybe_order_target: Option<&OrderTarget>, input_cardinality: f64, ) -> Result> { - let table_no = join_order.last().unwrap().table_no; + let table_no = join_order.last().unwrap().table_id; let mut best_access_method = AccessMethod::new_table_scan(input_cardinality, IterationDirection::Forwards); let rowid_column_idx = rhs_table.columns().iter().position(|c| c.is_rowid_alias); @@ -93,7 +93,7 @@ pub fn find_best_access_method_for_join_order<'a>( let mut all_same_direction = true; let mut all_opposite_direction = true; for i in 0..order_target.0.len().min(index_info.column_count) { - let correct_table = order_target.0[i].table_no == table_no; + let correct_table = order_target.0[i].table_id == table_no; let correct_column = { match &candidate.index { Some(index) => index.columns[i].pos_in_table == order_target.0[i].column_no, diff --git a/core/translate/optimizer/constraints.rs b/core/translate/optimizer/constraints.rs index a7272ce86..04398971f 100644 --- a/core/translate/optimizer/constraints.rs +++ b/core/translate/optimizer/constraints.rs @@ -9,7 +9,7 @@ use crate::{ }, Result, }; -use limbo_sqlite3_parser::ast::{self, SortOrder}; +use limbo_sqlite3_parser::ast::{self, SortOrder, TableInternalId}; use super::cost::ESTIMATED_HARDCODED_ROWS_PER_TABLE; @@ -133,7 +133,8 @@ pub struct ConstraintUseCandidate { #[derive(Debug)] /// A collection of [Constraint]s and their potential [ConstraintUseCandidate]s for a given table. pub struct TableConstraints { - pub table_no: usize, + /// The internal ID of the [TableReference] that these constraints are for. + pub table_id: TableInternalId, /// The constraints for the table, i.e. any [WhereTerm]s that reference columns from this table. pub constraints: Vec, /// Candidates for indexes that may use the constraints to perform a lookup. @@ -177,14 +178,14 @@ pub fn constraints_from_where_clause( let mut constraints = Vec::new(); // For each table, collect all the Constraints and all potential index candidates that may use them. - for (table_no, table_reference) in table_references.iter().enumerate() { + for table_reference in table_references.iter() { let rowid_alias_column = table_reference .columns() .iter() .position(|c| c.is_rowid_alias); let mut cs = TableConstraints { - table_no, + table_id: table_reference.internal_id, constraints: Vec::new(), candidates: available_indexes .get(table_reference.table.get_name()) @@ -212,7 +213,7 @@ pub fn constraints_from_where_clause( // Constraints originating from a LEFT JOIN must always be evaluated in that join's RHS table's loop, // regardless of which tables the constraint references. if let Some(outer_join_tbl) = term.from_outer_join { - if outer_join_tbl != table_no { + if outer_join_tbl != table_reference.internal_id { continue; } } @@ -220,13 +221,13 @@ pub fn constraints_from_where_clause( // If either the LHS or RHS of the constraint is a column from the table, add the constraint. match lhs { ast::Expr::Column { table, column, .. } => { - if *table == table_no { + if *table == table_reference.internal_id { let table_column = &table_reference.table.columns()[*column]; cs.constraints.push(Constraint { where_clause_pos: (i, BinaryExprSide::Rhs), operator, table_col_pos: *column, - lhs_mask: table_mask_from_expr(rhs)?, + lhs_mask: table_mask_from_expr(rhs, table_references)?, selectivity: estimate_selectivity(table_column, operator), }); } @@ -235,14 +236,14 @@ pub fn constraints_from_where_clause( // A rowid alias column must exist for the 'rowid' keyword to be considered a valid reference. // This should be a parse error at an earlier stage of the query compilation, but nevertheless, // we check it here. - if *table == table_no && rowid_alias_column.is_some() { + if *table == table_reference.internal_id && rowid_alias_column.is_some() { let table_column = &table_reference.table.columns()[rowid_alias_column.unwrap()]; cs.constraints.push(Constraint { where_clause_pos: (i, BinaryExprSide::Rhs), operator, table_col_pos: rowid_alias_column.unwrap(), - lhs_mask: table_mask_from_expr(rhs)?, + lhs_mask: table_mask_from_expr(rhs, table_references)?, selectivity: estimate_selectivity(table_column, operator), }); } @@ -251,26 +252,26 @@ pub fn constraints_from_where_clause( }; match rhs { ast::Expr::Column { table, column, .. } => { - if *table == table_no { + if *table == table_reference.internal_id { let table_column = &table_reference.table.columns()[*column]; cs.constraints.push(Constraint { where_clause_pos: (i, BinaryExprSide::Lhs), operator: opposite_cmp_op(operator), table_col_pos: *column, - lhs_mask: table_mask_from_expr(lhs)?, + lhs_mask: table_mask_from_expr(lhs, table_references)?, selectivity: estimate_selectivity(table_column, operator), }); } } ast::Expr::RowId { table, .. } => { - if *table == table_no && rowid_alias_column.is_some() { + if *table == table_reference.internal_id && rowid_alias_column.is_some() { let table_column = &table_reference.table.columns()[rowid_alias_column.unwrap()]; cs.constraints.push(Constraint { where_clause_pos: (i, BinaryExprSide::Lhs), operator: opposite_cmp_op(operator), table_col_pos: rowid_alias_column.unwrap(), - lhs_mask: table_mask_from_expr(lhs)?, + lhs_mask: table_mask_from_expr(lhs, table_references)?, selectivity: estimate_selectivity(table_column, operator), }); } @@ -383,11 +384,11 @@ pub fn usable_constraints_for_join_order<'a>( refs: &'a [ConstraintRef], join_order: &[JoinOrderMember], ) -> &'a [ConstraintRef] { - let table_no = join_order.last().unwrap().table_no; + let table_idx = join_order.last().unwrap().original_idx; let mut usable_until = 0; for cref in refs.iter() { let constraint = &constraints[cref.constraint_vec_pos]; - let other_side_refers_to_self = constraint.lhs_mask.contains_table(table_no); + let other_side_refers_to_self = constraint.lhs_mask.contains_table(table_idx); if other_side_refers_to_self { break; } @@ -395,7 +396,7 @@ pub fn usable_constraints_for_join_order<'a>( join_order .iter() .take(join_order.len() - 1) - .map(|j| j.table_no), + .map(|j| j.original_idx), ); let all_required_tables_are_on_left_side = lhs_mask.contains_all(&constraint.lhs_mask); if !all_required_tables_are_on_left_side { diff --git a/core/translate/optimizer/join.rs b/core/translate/optimizer/join.rs index 8e5174baa..23889b749 100644 --- a/core/translate/optimizer/join.rs +++ b/core/translate/optimizer/join.rs @@ -1,5 +1,7 @@ use std::{cell::RefCell, collections::HashMap}; +use limbo_sqlite3_parser::ast::TableInternalId; + use crate::{ translate::{ optimizer::{cost::Cost, order::plan_satisfies_order_target}, @@ -75,7 +77,7 @@ pub fn join_lhs_and_rhs<'a>( let mut best_access_methods = Vec::with_capacity(join_order.len()); best_access_methods.extend(lhs.map_or(vec![], |l| l.data.clone())); - let rhs_table_number = join_order.last().unwrap().table_no; + let rhs_table_number = join_order.last().unwrap().original_idx; best_access_methods.push((rhs_table_number, access_methods_arena.borrow().len() - 1)); let lhs_mask = lhs.map_or(TableMask::new(), |l| { @@ -163,7 +165,8 @@ pub fn compute_best_join_order<'a>( // Reuse a single mutable join order to avoid allocating join orders per permutation. let mut join_order = Vec::with_capacity(num_tables); join_order.push(JoinOrderMember { - table_no: 0, + table_id: TableInternalId::default(), + original_idx: 0, is_outer: false, }); @@ -187,7 +190,8 @@ pub fn compute_best_join_order<'a>( mask.add_table(i); let table_ref = &table_references[i]; join_order[0] = JoinOrderMember { - table_no: i, + table_id: table_ref.internal_id, + original_idx: i, is_outer: false, }; assert!(join_order.len() == 1); @@ -291,7 +295,8 @@ pub fn compute_best_join_order<'a>( // Build a JoinOrder out of the table bitmask we are now considering. for table_no in lhs.table_numbers() { join_order.push(JoinOrderMember { - table_no, + table_id: table_references[table_no].internal_id, + original_idx: table_no, is_outer: table_references[table_no] .join_info .as_ref() @@ -299,7 +304,8 @@ pub fn compute_best_join_order<'a>( }); } join_order.push(JoinOrderMember { - table_no: rhs_idx, + table_id: table_references[rhs_idx].internal_id, + original_idx: rhs_idx, is_outer: table_references[rhs_idx] .join_info .as_ref() @@ -402,7 +408,8 @@ pub fn compute_naive_left_deep_plan<'a>( .iter() .enumerate() .map(|(i, t)| JoinOrderMember { - table_no: i, + table_id: t.internal_id, + original_idx: i, is_outer: t.join_info.as_ref().map_or(false, |j| j.outer), }) .collect::>(); @@ -489,7 +496,7 @@ fn generate_join_bitmasks(table_number_max_exclusive: usize, how_many: usize) -> mod tests { use std::{rc::Rc, sync::Arc}; - use limbo_sqlite3_parser::ast::{self, Expr, Operator, SortOrder}; + use limbo_sqlite3_parser::ast::{self, Expr, Operator, SortOrder, TableInternalId}; use super::*; use crate::{ @@ -499,6 +506,7 @@ mod tests { plan::{ColumnUsedMask, IterationDirection, JoinInfo, Operation, WhereTerm}, planner::TableMask, }, + vdbe::builder::TableRefIdCounter, }; #[test] @@ -538,7 +546,12 @@ mod tests { /// Test that [compute_best_join_order] returns a table scan access method when the where clause is empty. fn test_compute_best_join_order_single_table_no_indexes() { let t1 = _create_btree_table("test_table", _create_column_list(&["id"], Type::Integer)); - let table_references = vec![_create_table_reference(t1.clone(), None)]; + let mut table_id_counter = TableRefIdCounter::new(); + let table_references = vec![_create_table_reference( + t1.clone(), + None, + table_id_counter.next(), + )]; let available_indexes = HashMap::new(); let where_clause = vec![]; @@ -567,10 +580,15 @@ mod tests { /// Test that [compute_best_join_order] returns a RowidEq access method when the where clause has an EQ constraint on the rowid alias. fn test_compute_best_join_order_single_table_rowid_eq() { let t1 = _create_btree_table("test_table", vec![_create_column_rowid_alias("id")]); - let table_references = vec![_create_table_reference(t1.clone(), None)]; + let mut table_id_counter = TableRefIdCounter::new(); + let table_references = vec![_create_table_reference( + t1.clone(), + None, + table_id_counter.next(), + )]; let where_clause = vec![_create_binary_expr( - _create_column_expr(0, 0, true), // table 0, column 0 (rowid) + _create_column_expr(table_references[0].internal_id, 0, true), // table 0, column 0 (rowid) ast::Operator::Equals, _create_numeric_literal("42"), )]; @@ -611,10 +629,15 @@ mod tests { "test_table", vec![_create_column_of_type("id", Type::Integer)], ); - let table_references = vec![_create_table_reference(t1.clone(), None)]; + let mut table_id_counter = TableRefIdCounter::new(); + let table_references = vec![_create_table_reference( + t1.clone(), + None, + table_id_counter.next(), + )]; let where_clause = vec![_create_binary_expr( - _create_column_expr(0, 0, false), // table 0, column 0 (id) + _create_column_expr(table_references[0].internal_id, 0, false), // table 0, column 0 (id) ast::Operator::Equals, _create_numeric_literal("42"), )]; @@ -670,14 +693,16 @@ mod tests { let t1 = _create_btree_table("table1", _create_column_list(&["id"], Type::Integer)); let t2 = _create_btree_table("table2", _create_column_list(&["id"], Type::Integer)); + let mut table_id_counter = TableRefIdCounter::new(); let mut table_references = vec![ - _create_table_reference(t1.clone(), None), + _create_table_reference(t1.clone(), None, table_id_counter.next()), _create_table_reference( t2.clone(), Some(JoinInfo { outer: false, using: None, }), + table_id_counter.next(), ), ]; @@ -705,9 +730,9 @@ mod tests { // SELECT * FROM table1 JOIN table2 WHERE table1.id = table2.id // expecting table2 to be chosen first due to the index on table1.id let where_clause = vec![_create_binary_expr( - _create_column_expr(TABLE1, 0, false), // table1.id + _create_column_expr(table_references[TABLE1].internal_id, 0, false), // table1.id ast::Operator::Equals, - _create_column_expr(TABLE2, 0, false), // table2.id + _create_column_expr(table_references[TABLE2].internal_id, 0, false), // table2.id )]; let access_methods_arena = RefCell::new(Vec::new()); @@ -769,14 +794,16 @@ mod tests { ], ); + let mut table_id_counter = TableRefIdCounter::new(); let table_references = vec![ - _create_table_reference(table_orders.clone(), None), + _create_table_reference(table_orders.clone(), None, table_id_counter.next()), _create_table_reference( table_customers.clone(), Some(JoinInfo { outer: false, using: None, }), + table_id_counter.next(), ), _create_table_reference( table_order_items.clone(), @@ -784,6 +811,7 @@ mod tests { outer: false, using: None, }), + table_id_counter.next(), ), ]; @@ -857,19 +885,19 @@ mod tests { let where_clause = vec![ // orders.customer_id = customers.id _create_binary_expr( - _create_column_expr(TABLE_NO_ORDERS, 1, false), // orders.customer_id + _create_column_expr(table_references[TABLE_NO_ORDERS].internal_id, 1, false), // orders.customer_id ast::Operator::Equals, - _create_column_expr(TABLE_NO_CUSTOMERS, 0, false), // customers.id + _create_column_expr(table_references[TABLE_NO_CUSTOMERS].internal_id, 0, false), // customers.id ), // orders.id = order_items.order_id _create_binary_expr( - _create_column_expr(TABLE_NO_ORDERS, 0, false), // orders.id + _create_column_expr(table_references[TABLE_NO_ORDERS].internal_id, 0, false), // orders.id ast::Operator::Equals, - _create_column_expr(TABLE_NO_ORDER_ITEMS, 1, false), // order_items.order_id + _create_column_expr(table_references[TABLE_NO_ORDER_ITEMS].internal_id, 1, false), // order_items.order_id ), // customers.id = 42 _create_binary_expr( - _create_column_expr(TABLE_NO_CUSTOMERS, 0, false), // customers.id + _create_column_expr(table_references[TABLE_NO_CUSTOMERS].internal_id, 0, false), // customers.id ast::Operator::Equals, _create_numeric_literal("42"), ), @@ -946,14 +974,16 @@ mod tests { let t2 = _create_btree_table("t2", _create_column_list(&["id", "foo"], Type::Integer)); let t3 = _create_btree_table("t3", _create_column_list(&["id", "foo"], Type::Integer)); + let mut table_id_counter = TableRefIdCounter::new(); let mut table_references = vec![ - _create_table_reference(t1.clone(), None), + _create_table_reference(t1.clone(), None, table_id_counter.next()), _create_table_reference( t2.clone(), Some(JoinInfo { outer: false, using: None, }), + table_id_counter.next(), ), _create_table_reference( t3.clone(), @@ -961,19 +991,20 @@ mod tests { outer: false, using: None, }), + table_id_counter.next(), ), ]; let where_clause = vec![ // t2.foo = 42 (equality filter, more selective) _create_binary_expr( - _create_column_expr(1, 1, false), // table 1, column 1 (foo) + _create_column_expr(table_references[1].internal_id, 1, false), // table 1, column 1 (foo) ast::Operator::Equals, _create_numeric_literal("42"), ), // t1.foo > 10 (inequality filter, less selective) _create_binary_expr( - _create_column_expr(0, 1, false), // table 0, column 1 (foo) + _create_column_expr(table_references[0].internal_id, 1, false), // table 0, column 1 (foo) ast::Operator::Greater, _create_numeric_literal("10"), ), @@ -1043,19 +1074,13 @@ mod tests { }) .collect(); - let mut where_clause = vec![]; - - // Add join conditions between fact and each dimension table - for i in 0..NUM_DIM_TABLES { - where_clause.push(_create_binary_expr( - _create_column_expr(FACT_TABLE_IDX, i + 1, false), // fact.dimX_id - ast::Operator::Equals, - _create_column_expr(i, 0, true), // dimX.id - )); - } - + let mut table_id_counter = TableRefIdCounter::new(); let table_references = { - let mut refs = vec![_create_table_reference(dim_tables[0].clone(), None)]; + let mut refs = vec![_create_table_reference( + dim_tables[0].clone(), + None, + table_id_counter.next(), + )]; refs.extend(dim_tables.iter().skip(1).map(|t| { _create_table_reference( t.clone(), @@ -1063,6 +1088,7 @@ mod tests { outer: false, using: None, }), + table_id_counter.next(), ) })); refs.push(_create_table_reference( @@ -1071,10 +1097,24 @@ mod tests { outer: false, using: None, }), + table_id_counter.next(), )); refs }; + let mut where_clause = vec![]; + + // Add join conditions between fact and each dimension table + for i in 0..NUM_DIM_TABLES { + let internal_id_fact = table_references[FACT_TABLE_IDX].internal_id; + let internal_id_other = table_references[i].internal_id; + where_clause.push(_create_binary_expr( + _create_column_expr(internal_id_fact, i + 1, false), // fact.dimX_id + ast::Operator::Equals, + _create_column_expr(internal_id_other, 0, true), // dimX.id + )); + } + let access_methods_arena = RefCell::new(Vec::new()); let available_indexes = HashMap::new(); let table_constraints = @@ -1140,19 +1180,22 @@ mod tests { let available_indexes = HashMap::new(); + let mut table_id_counter = TableRefIdCounter::new(); // Create table references let table_references: Vec<_> = tables .iter() - .map(|t| _create_table_reference(t.clone(), None)) + .map(|t| _create_table_reference(t.clone(), None, table_id_counter.next())) .collect(); // Create where clause linking each table to the next let mut where_clause = Vec::new(); for i in 0..NUM_TABLES - 1 { + let internal_id_left = table_references[i].internal_id; + let internal_id_right = table_references[i + 1].internal_id; where_clause.push(_create_binary_expr( - _create_column_expr(i, 1, false), // ti.next_id + _create_column_expr(internal_id_left, 1, false), // ti.next_id ast::Operator::Equals, - _create_column_expr(i + 1, 0, true), // t(i+1).id + _create_column_expr(internal_id_right, 0, true), // t(i+1).id )); } @@ -1258,6 +1301,7 @@ mod tests { fn _create_table_reference( table: Rc, join_info: Option, + internal_id: TableInternalId, ) -> TableReference { let name = table.name.clone(); TableReference { @@ -1267,13 +1311,14 @@ mod tests { index: None, }, identifier: name, + internal_id, join_info, col_used_mask: ColumnUsedMask::new(), } } /// Creates a column expression - fn _create_column_expr(table: usize, column: usize, is_rowid_alias: bool) -> Expr { + fn _create_column_expr(table: TableInternalId, column: usize, is_rowid_alias: bool) -> Expr { Expr::Column { database: None, table, diff --git a/core/translate/optimizer/lift_common_subexpressions.rs b/core/translate/optimizer/lift_common_subexpressions.rs index 80e87cc8e..b7302cb7e 100644 --- a/core/translate/optimizer/lift_common_subexpressions.rs +++ b/core/translate/optimizer/lift_common_subexpressions.rs @@ -188,7 +188,7 @@ fn rebuild_or_expr_from_list(mut operands: Vec) -> Expr { mod tests { use super::*; use crate::translate::plan::WhereTerm; - use limbo_sqlite3_parser::ast::{self, Expr, Literal, Operator}; + use limbo_sqlite3_parser::ast::{self, Expr, Literal, Operator, TableInternalId}; #[test] fn test_lift_common_subexpressions() -> Result<()> { @@ -200,7 +200,7 @@ mod tests { let a_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 0, is_rowid_alias: false, }), @@ -211,7 +211,7 @@ mod tests { let b_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 1, is_rowid_alias: false, }), @@ -222,7 +222,7 @@ mod tests { let x_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 2, is_rowid_alias: false, }), @@ -233,7 +233,7 @@ mod tests { let y_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 3, is_rowid_alias: false, }), @@ -293,7 +293,7 @@ mod tests { let a_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 0, is_rowid_alias: false, }), @@ -304,7 +304,7 @@ mod tests { let x_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 1, is_rowid_alias: false, }), @@ -315,7 +315,7 @@ mod tests { let y_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 2, is_rowid_alias: false, }), @@ -326,7 +326,7 @@ mod tests { let z_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 3, is_rowid_alias: false, }), @@ -393,7 +393,7 @@ mod tests { let x_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 0, is_rowid_alias: false, }), @@ -404,7 +404,7 @@ mod tests { let y_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 1, is_rowid_alias: false, }), @@ -445,7 +445,7 @@ mod tests { let a_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 0, is_rowid_alias: false, }), @@ -456,7 +456,7 @@ mod tests { let x_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 1, is_rowid_alias: false, }), @@ -467,7 +467,7 @@ mod tests { let y_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 2, is_rowid_alias: false, }), @@ -487,7 +487,7 @@ mod tests { let mut where_clause = vec![WhereTerm { expr: or_expr, - from_outer_join: Some(0), // Set from_outer_join + from_outer_join: Some(TableInternalId::default()), // Set from_outer_join consumed: false, }]; @@ -507,9 +507,15 @@ mod tests { Box::new(ast::Expr::Parenthesized(vec![y_expr])) ) ); - assert_eq!(nonconsumed_terms[0].from_outer_join, Some(0)); + assert_eq!( + nonconsumed_terms[0].from_outer_join, + Some(TableInternalId::default()) + ); assert_eq!(nonconsumed_terms[1].expr, a_expr); - assert_eq!(nonconsumed_terms[1].from_outer_join, Some(0)); + assert_eq!( + nonconsumed_terms[1].from_outer_join, + Some(TableInternalId::default()) + ); Ok(()) } @@ -523,7 +529,7 @@ mod tests { let single_expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: 0, is_rowid_alias: false, }), @@ -559,7 +565,7 @@ mod tests { Expr::Binary( Box::new(Expr::Column { database: None, - table: 0, + table: TableInternalId::default(), column: i, is_rowid_alias: false, }), diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index c4fe27956..496d5ad38 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -211,7 +211,8 @@ fn optimize_table_access( let best_join_order: Vec = best_table_numbers .into_iter() .map(|table_number| JoinOrderMember { - table_no: table_number, + table_id: table_references[table_number].internal_id, + original_idx: table_number, is_outer: table_references[table_number] .join_info .as_ref() @@ -220,20 +221,20 @@ fn optimize_table_access( .collect(); // Mutate the Operations in `table_references` to use the selected access methods. for (i, join_order_member) in best_join_order.iter().enumerate() { - let table_number = join_order_member.table_no; + let table_idx = join_order_member.original_idx; let access_method = &access_methods_arena.borrow()[best_access_methods[i]]; if access_method.is_scan() { let is_leftmost_table = i == 0; let uses_index = access_method.index.is_some(); let source_table_is_from_clause_subquery = matches!( - &table_references[table_number].table, + &table_references[table_idx].table, Table::FromClauseSubquery(_) ); let try_to_build_ephemeral_index = !is_leftmost_table && !uses_index && !source_table_is_from_clause_subquery; if !try_to_build_ephemeral_index { - table_references[table_number].op = Operation::Scan { + table_references[table_idx].op = Operation::Scan { iter_dir: access_method.iter_dir, index: access_method.index.clone(), }; @@ -243,9 +244,9 @@ fn optimize_table_access( // Try to construct an ephemeral index since it's going to be better than a scan. let table_constraints = constraints_per_table .iter() - .find(|c| c.table_no == table_number); + .find(|c| c.table_id == join_order_member.table_id); let Some(table_constraints) = table_constraints else { - table_references[table_number].op = Operation::Scan { + table_references[table_idx].op = Operation::Scan { iter_dir: access_method.iter_dir, index: access_method.index.clone(), }; @@ -264,20 +265,19 @@ fn optimize_table_access( &best_join_order[..=i], ); if usable_constraint_refs.is_empty() { - table_references[table_number].op = Operation::Scan { + table_references[table_idx].op = Operation::Scan { iter_dir: access_method.iter_dir, index: access_method.index.clone(), }; continue; } let ephemeral_index = ephemeral_index_build( - &table_references[table_number], - table_number, + &table_references[table_idx], &table_constraints.constraints, &usable_constraint_refs, ); let ephemeral_index = Arc::new(ephemeral_index); - table_references[table_number].op = Operation::Search(Search::Seek { + table_references[table_idx].op = Operation::Search(Search::Seek { index: Some(ephemeral_index), seek_def: build_seek_def_from_constraints( &table_constraints.constraints, @@ -291,7 +291,7 @@ fn optimize_table_access( assert!(!constraint_refs.is_empty()); for cref in constraint_refs.iter() { let constraint = - &constraints_per_table[table_number].constraints[cref.constraint_vec_pos]; + &constraints_per_table[table_idx].constraints[cref.constraint_vec_pos]; assert!( !where_clause[constraint.where_clause_pos.0].consumed, "trying to consume a where clause term twice: {:?}", @@ -300,10 +300,10 @@ fn optimize_table_access( where_clause[constraint.where_clause_pos.0].consumed = true; } if let Some(index) = &access_method.index { - table_references[table_number].op = Operation::Search(Search::Seek { + table_references[table_idx].op = Operation::Search(Search::Seek { index: Some(index.clone()), seek_def: build_seek_def_from_constraints( - &constraints_per_table[table_number].constraints, + &constraints_per_table[table_idx].constraints, &constraint_refs, access_method.iter_dir, where_clause, @@ -316,16 +316,16 @@ fn optimize_table_access( "expected exactly one constraint for rowid seek, got {:?}", constraint_refs ); - let constraint = &constraints_per_table[table_number].constraints + let constraint = &constraints_per_table[table_idx].constraints [constraint_refs[0].constraint_vec_pos]; - table_references[table_number].op = match constraint.operator { + table_references[table_idx].op = match constraint.operator { ast::Operator::Equals => Operation::Search(Search::RowidEq { cmp_expr: constraint.get_constraining_expr(where_clause), }), _ => Operation::Search(Search::Seek { index: None, seek_def: build_seek_def_from_constraints( - &constraints_per_table[table_number].constraints, + &constraints_per_table[table_idx].constraints, &constraint_refs, access_method.iter_dir, where_clause, @@ -505,7 +505,7 @@ impl Optimizable for ast::Expr { return true; } - let table_ref = &tables[*table]; + let table_ref = tables.iter().find(|t| t.internal_id == *table).unwrap(); let columns = table_ref.columns(); let column = &columns[*column]; return column.primary_key || column.notnull; @@ -747,7 +747,6 @@ impl Optimizable for ast::Expr { fn ephemeral_index_build( table_reference: &TableReference, - table_index: usize, constraints: &[Constraint], constraint_refs: &[ConstraintRef], ) -> Index { @@ -785,7 +784,7 @@ fn ephemeral_index_build( name: format!( "ephemeral_{}_{}", table_reference.table.get_name(), - table_index + table_reference.internal_id ), columns: ephemeral_columns, unique: false, diff --git a/core/translate/optimizer/order.rs b/core/translate/optimizer/order.rs index 7f153b10c..51b05a5a5 100644 --- a/core/translate/optimizer/order.rs +++ b/core/translate/optimizer/order.rs @@ -1,6 +1,6 @@ use std::cell::RefCell; -use limbo_sqlite3_parser::ast::{self, SortOrder}; +use limbo_sqlite3_parser::ast::{self, SortOrder, TableInternalId}; use crate::{ translate::plan::{GroupBy, IterationDirection, TableReference}, @@ -12,7 +12,7 @@ use super::{access_method::AccessMethod, join::JoinN}; #[derive(Debug, PartialEq, Clone)] /// A convenience struct for representing a (table_no, column_no, [SortOrder]) tuple. pub struct ColumnOrder { - pub table_no: usize, + pub table_id: TableInternalId, pub column_no: usize, pub order: SortOrder, } @@ -51,7 +51,7 @@ impl OrderTarget { unreachable!(); }; ColumnOrder { - table_no: *table, + table_id: *table, column_no: *column, order, } @@ -162,10 +162,10 @@ pub fn plan_satisfies_order_target( ) -> bool { let mut target_col_idx = 0; let num_cols_in_order_target = order_target.0.len(); - for (table_no, access_method_index) in plan.data.iter() { + for (table_index, access_method_index) in plan.data.iter() { let target_col = &order_target.0[target_col_idx]; - let table_ref = &table_references[*table_no]; - let correct_table = target_col.table_no == *table_no; + let table_ref = &table_references[*table_index]; + let correct_table = target_col.table_id == table_ref.internal_id; if !correct_table { return false; } diff --git a/core/translate/order_by.rs b/core/translate/order_by.rs index 7e1e71232..7047960d4 100644 --- a/core/translate/order_by.rs +++ b/core/translate/order_by.rs @@ -54,7 +54,10 @@ pub fn init_order_by( .map(|(expr, _)| match expr { ast::Expr::Collate(_, collation_name) => CollationSeq::new(collation_name).map(Some), ast::Expr::Column { table, column, .. } => { - let table_reference = referenced_tables.get(*table).unwrap(); + let table_reference = referenced_tables + .iter() + .find(|t| t.internal_id == *table) + .unwrap(); let Some(table_column) = table_reference.table.get_column_at(*column) else { crate::bail_parse_error!("column index out of bounds"); diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 15630b9b8..4ee6b8047 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -21,6 +21,8 @@ use crate::{ }; use crate::{schema::Type, types::SeekOp, util::can_pushdown_predicate}; +use limbo_sqlite3_parser::ast::TableInternalId; + use super::{emitter::OperationMode, planner::determine_where_to_eval_term, schema::ParseSchema}; #[derive(Debug, Clone)] @@ -38,11 +40,18 @@ impl ResultSetColumn { } match &self.expr { ast::Expr::Column { table, column, .. } => { - tables[*table].columns()[*column].name.as_deref() + let table_ref = tables.iter().find(|t| t.internal_id == *table).unwrap(); + table_ref + .table + .get_column_at(*column) + .unwrap() + .name + .as_deref() } ast::Expr::RowId { table, .. } => { // If there is a rowid alias column, use its name - if let Table::BTree(table) = &tables[*table].table { + let table_ref = tables.iter().find(|t| t.internal_id == *table).unwrap(); + if let Table::BTree(table) = &table_ref.table { if let Some(rowid_alias_column) = table.get_rowid_alias_column() { if let Some(name) = &rowid_alias_column.1.name { return Some(name); @@ -77,11 +86,12 @@ pub struct GroupBy { pub struct WhereTerm { /// The original condition expression. pub expr: ast::Expr, - /// Is this condition originally from an OUTER JOIN, and which table number in the plan's [TableReference] vector? - /// If so, we need to evaluate it at the loop of the right table in that JOIN, + /// Is this condition originally from an OUTER JOIN, and if so, what is the internal ID of the [TableReference] that it came from? + /// The ID is always the right-hand-side table of the OUTER JOIN. + /// If `from_outer_join` is Some, we need to evaluate this term at the loop of the the corresponding table, /// regardless of which tables it references. /// We also cannot e.g. short circuit the entire query in the optimizer if the condition is statically false. - pub from_outer_join: Option, + pub from_outer_join: Option, /// Whether the condition has been consumed by the optimizer in some way, and it should not be evaluated /// in the normal place where WHERE terms are evaluated. /// A term may have been consumed e.g. if: @@ -159,8 +169,9 @@ fn to_ext_constraint_op(op: &Operator) -> Option { /// the filtration in the vdbe layer. pub fn convert_where_to_vtab_constraint( term: &WhereTerm, - table_index: usize, + table_idx: usize, pred_idx: usize, + join_order: &[JoinOrderMember], ) -> Result> { if term.from_outer_join.is_some() { return Ok(None); @@ -168,7 +179,8 @@ pub fn convert_where_to_vtab_constraint( let Expr::Binary(lhs, op, rhs) = &term.expr else { return Ok(None); }; - let expr_is_ready = |e: &Expr| -> Result { can_pushdown_predicate(e, table_index) }; + let expr_is_ready = + |e: &Expr| -> Result { can_pushdown_predicate(e, table_idx, join_order) }; let (vcol_idx, op_for_vtab, usable, is_rhs) = match (&**lhs, &**rhs) { ( Expr::Column { @@ -183,23 +195,37 @@ pub fn convert_where_to_vtab_constraint( }, ) => { // one side must be the virtual table - let vtab_on_l = *tbl_l == table_index; - let vtab_on_r = *tbl_r == table_index; + let tbl_l_idx = join_order + .iter() + .position(|j| j.table_id == *tbl_l) + .unwrap(); + let tbl_r_idx = join_order + .iter() + .position(|j| j.table_id == *tbl_r) + .unwrap(); + let vtab_on_l = tbl_l_idx == table_idx; + let vtab_on_r = tbl_r_idx == table_idx; if vtab_on_l == vtab_on_r { return Ok(None); // either both or none -> not convertible } if vtab_on_l { // vtab on left side: operator unchanged - let usable = *tbl_r < table_index; // usable if the other table is already positioned + let usable = tbl_r_idx < table_idx; // usable if the other table is already positioned (col_l, op, usable, false) } else { // vtab on right side of the expr: reverse operator - let usable = *tbl_l < table_index; + let usable = tbl_l_idx < table_idx; (col_r, &reverse_operator(op).unwrap_or(*op), usable, true) } } - (Expr::Column { table, column, .. }, other) if *table == table_index => { + (Expr::Column { table, column, .. }, other) + if join_order + .iter() + .position(|j| j.table_id == *table) + .unwrap() + == table_idx => + { ( column, op, @@ -207,12 +233,20 @@ pub fn convert_where_to_vtab_constraint( false, ) } - (other, Expr::Column { table, column, .. }) if *table == table_index => ( - column, - &reverse_operator(op).unwrap_or(*op), - expr_is_ready(other)?, - true, - ), + (other, Expr::Column { table, column, .. }) + if join_order + .iter() + .position(|j| j.table_id == *table) + .unwrap() + == table_idx => + { + ( + column, + &reverse_operator(op).unwrap_or(*op), + expr_is_ready(other)?, + true, + ) + } _ => return Ok(None), // does not involve the virtual table at all }; @@ -289,8 +323,11 @@ pub enum SelectQueryType { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct JoinOrderMember { - /// The index of the table in the plan's vector of [TableReference] - pub table_no: usize, + /// The internal ID of the[TableReference] + pub table_id: TableInternalId, + /// The index of the table in the original join order. + /// This is used to index into e.g. [SelectPlan::table_references] + pub original_idx: usize, /// Whether this member is the right side of an OUTER JOIN pub is_outer: bool, } @@ -298,7 +335,8 @@ pub struct JoinOrderMember { impl Default for JoinOrderMember { fn default() -> Self { Self { - table_no: 0, + table_id: TableInternalId::default(), + original_idx: 0, is_outer: false, } } @@ -524,7 +562,7 @@ pub enum IterationDirection { } pub fn select_star(tables: &[TableReference], out_columns: &mut Vec) { - for (current_table_index, table) in tables.iter().enumerate() { + for table in tables.iter() { let maybe_using_cols = table .join_info .as_ref() @@ -551,7 +589,7 @@ pub fn select_star(tables: &[TableReference], out_columns: &mut Vec, /// Bitmask of columns that are referenced in the query. @@ -671,7 +711,12 @@ impl TableReference { } /// Creates a new TableReference for a subquery. - pub fn new_subquery(identifier: String, plan: SelectPlan, join_info: Option) -> Self { + pub fn new_subquery( + identifier: String, + plan: SelectPlan, + join_info: Option, + internal_id: TableInternalId, + ) -> Self { let columns = plan .result_columns .iter() @@ -701,6 +746,7 @@ impl TableReference { }, table, identifier, + internal_id, join_info, col_used_mask: ColumnUsedMask::new(), } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 525a612fa..508fbf345 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -13,11 +13,11 @@ use crate::{ schema::{Schema, Table}, translate::expr::walk_expr_mut, util::{exprs_are_equivalent, normalize_ident, vtable_args}, - vdbe::BranchOffset, + vdbe::{builder::TableRefIdCounter, BranchOffset}, Result, }; use limbo_sqlite3_parser::ast::{ - self, Expr, FromClause, JoinType, Limit, Materialized, UnaryOperator, With, + self, Expr, FromClause, JoinType, Limit, Materialized, TableInternalId, UnaryOperator, With, }; pub const ROWID: &str = "rowid"; @@ -110,7 +110,9 @@ pub fn bind_column_references( if !referenced_tables.is_empty() { if let Some(row_id_expr) = - parse_row_id(&normalized_id, 0, || referenced_tables.len() != 1)? + parse_row_id(&normalized_id, referenced_tables[0].internal_id, || { + referenced_tables.len() != 1 + })? { *expr = row_id_expr; @@ -135,7 +137,7 @@ pub fn bind_column_references( if let Some((tbl_idx, col_idx, is_rowid_alias)) = match_result { *expr = Expr::Column { database: None, // TODO: support different databases - table: tbl_idx, + table: referenced_tables[tbl_idx].internal_id, column: col_idx, is_rowid_alias, }; @@ -167,7 +169,11 @@ pub fn bind_column_references( let tbl_idx = matching_tbl_idx.unwrap(); let normalized_id = normalize_ident(id.0.as_str()); - if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_idx, || false)? { + if let Some(row_id_expr) = parse_row_id( + &normalized_id, + referenced_tables[tbl_idx].internal_id, + || false, + )? { *expr = row_id_expr; return Ok(()); @@ -186,7 +192,7 @@ pub fn bind_column_references( .unwrap(); *expr = Expr::Column { database: None, // TODO: support different databases - table: tbl_idx, + table: referenced_tables[tbl_idx].internal_id, column: col_idx.unwrap(), is_rowid_alias: col.is_rowid_alias, }; @@ -203,6 +209,7 @@ fn parse_from_clause_table<'a>( table: ast::SelectTable, scope: &mut Scope<'a>, syms: &SymbolTable, + table_ref_counter: &mut TableRefIdCounter, ) -> Result<()> { match table { ast::SelectTable::Table(qualified_name, maybe_alias, _) => { @@ -215,8 +222,12 @@ fn parse_from_clause_table<'a>( { // CTE can be rewritten as a subquery. // TODO: find a way not to clone the CTE plan here. - let cte_table = - TableReference::new_subquery(cte.name.clone(), cte.plan.clone(), None); + let cte_table = TableReference::new_subquery( + cte.name.clone(), + cte.plan.clone(), + None, + table_ref_counter.next(), + ); scope.tables.push(cte_table); return Ok(()); }; @@ -244,6 +255,7 @@ fn parse_from_clause_table<'a>( }, table: tbl_ref, identifier: alias.unwrap_or(normalized_qualified_name), + internal_id: table_ref_counter.next(), join_info: None, col_used_mask: ColumnUsedMask::new(), }); @@ -267,8 +279,12 @@ fn parse_from_clause_table<'a>( .find(|cte| cte.name == normalized_qualified_name) { // TODO: avoid cloning the CTE plan here. - let cte_table = - TableReference::new_subquery(cte.name.clone(), cte.plan.clone(), None); + let cte_table = TableReference::new_subquery( + cte.name.clone(), + cte.plan.clone(), + None, + table_ref_counter.next(), + ); scope.tables.push(cte_table); return Ok(()); } @@ -278,7 +294,7 @@ fn parse_from_clause_table<'a>( } ast::SelectTable::Select(subselect, maybe_alias) => { let Plan::Select(mut subplan) = - prepare_select_plan(schema, *subselect, syms, Some(scope))? + prepare_select_plan(schema, *subselect, syms, Some(scope), table_ref_counter)? else { crate::bail_parse_error!("Only non-compound SELECT queries are currently supported in FROM clause subqueries"); }; @@ -293,9 +309,12 @@ fn parse_from_clause_table<'a>( ast::As::Elided(id) => id.0.clone(), }) .unwrap_or(format!("subquery_{}", cur_table_index)); - scope - .tables - .push(TableReference::new_subquery(identifier, subplan, None)); + scope.tables.push(TableReference::new_subquery( + identifier, + subplan, + None, + table_ref_counter.next(), + )); Ok(()) } ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => { @@ -328,6 +347,7 @@ fn parse_from_clause_table<'a>( join_info: None, table: Table::Virtual(vtab), identifier: alias, + internal_id: table_ref_counter.next(), col_used_mask: ColumnUsedMask::new(), }); @@ -384,6 +404,7 @@ pub fn parse_from<'a>( with: Option, out_where_clause: &mut Vec, outer_scope: Option<&'a Scope<'a>>, + table_ref_counter: &mut TableRefIdCounter, ) -> Result> { if from.as_ref().and_then(|f| f.select.as_ref()).is_none() { return Ok(vec![]); @@ -429,7 +450,8 @@ pub fn parse_from<'a>( } // CTE can refer to other CTEs that came before it, plus any schema tables or tables in the outer scope. - let cte_plan = prepare_select_plan(schema, *cte.select, syms, Some(&scope))?; + let cte_plan = + prepare_select_plan(schema, *cte.select, syms, Some(&scope), table_ref_counter)?; let Plan::Select(mut cte_plan) = cte_plan else { crate::bail_parse_error!("Only SELECT queries are currently supported in CTEs"); }; @@ -448,10 +470,17 @@ pub fn parse_from<'a>( let mut from_owned = std::mem::take(&mut from).unwrap(); let select_owned = *std::mem::take(&mut from_owned.select).unwrap(); let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default(); - parse_from_clause_table(schema, select_owned, &mut scope, syms)?; + parse_from_clause_table(schema, select_owned, &mut scope, syms, table_ref_counter)?; for join in joins_owned.into_iter() { - parse_join(schema, join, syms, &mut scope, out_where_clause)?; + parse_join( + schema, + join, + syms, + &mut scope, + out_where_clause, + table_ref_counter, + )?; } Ok(scope.tables) @@ -493,11 +522,11 @@ pub fn determine_where_to_eval_term( term: &WhereTerm, join_order: &[JoinOrderMember], ) -> Result { - if let Some(table_no) = term.from_outer_join { + if let Some(table_id) = term.from_outer_join { return Ok(EvalAt::Loop( join_order .iter() - .position(|t| t.table_no == table_no) + .position(|t| t.table_id == table_id) .unwrap_or(usize::MAX), )); } @@ -523,6 +552,11 @@ pub fn determine_where_to_eval_term( /// [TableMask] helps determine: /// - Which tables are referenced in a constraint /// - When a constraint can be applied as a join condition (all referenced tables must be on the left side of the table being joined) +/// +/// Note that although [TableReference]s contain an internal ID as well, in join order optimization +/// the [TableMask] refers to the index of the table in the original join order, not the internal ID. +/// This is simply because we want to represent the tables as a contiguous set of bits, and the internal ID +/// might not be contiguous after e.g. subquery unnesting or other transformations. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TableMask(pub u128); @@ -597,12 +631,19 @@ impl TableMask { /// Returns a [TableMask] representing the tables referenced in the given expression. /// Used in the optimizer for constraint analysis. -pub fn table_mask_from_expr(top_level_expr: &Expr) -> Result { +pub fn table_mask_from_expr( + top_level_expr: &Expr, + table_references: &[TableReference], +) -> Result { let mut mask = TableMask::new(); walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> { match expr { Expr::Column { table, .. } | Expr::RowId { table, .. } => { - mask.add_table(*table); + let table_idx = table_references + .iter() + .position(|t| t.internal_id == *table) + .expect("table not found in table_references"); + mask.add_table(table_idx); } _ => {} } @@ -622,7 +663,7 @@ pub fn determine_where_to_eval_expr<'a>( Expr::Column { table, .. } | Expr::RowId { table, .. } => { let join_idx = join_order .iter() - .position(|t| t.table_no == *table) + .position(|t| t.table_id == *table) .unwrap_or(usize::MAX); eval_at = eval_at.max(EvalAt::Loop(join_idx)); } @@ -640,6 +681,7 @@ fn parse_join<'a>( syms: &SymbolTable, scope: &mut Scope<'a>, out_where_clause: &mut Vec, + table_ref_counter: &mut TableRefIdCounter, ) -> Result<()> { let ast::JoinedSelectTable { operator: join_operator, @@ -647,7 +689,7 @@ fn parse_join<'a>( constraint, } = join; - parse_from_clause_table(schema, table, scope, syms)?; + parse_from_clause_table(schema, table, scope, syms, table_ref_counter)?; let (outer, natural) = match join_operator { ast::JoinOperator::TypedJoin(Some(join_type)) => { @@ -717,7 +759,7 @@ fn parse_join<'a>( out_where_clause.push(WhereTerm { expr: pred, from_outer_join: if outer { - Some(scope.tables.len() - 1) + Some(scope.tables.last().unwrap().internal_id) } else { None }, @@ -744,7 +786,7 @@ fn parse_join<'a>( .as_ref() .map_or(false, |name| *name == name_normalized) }) - .map(|(idx, col)| (left_table_idx, idx, col)); + .map(|(idx, col)| (left_table_idx, left_table.internal_id, idx, col)); if left_col.is_some() { break; } @@ -766,19 +808,19 @@ fn parse_join<'a>( distinct_name.0 ); } - let (left_table_idx, left_col_idx, left_col) = left_col.unwrap(); + let (left_table_idx, left_table_id, left_col_idx, left_col) = left_col.unwrap(); let (right_col_idx, right_col) = right_col.unwrap(); let expr = Expr::Binary( Box::new(Expr::Column { database: None, - table: left_table_idx, + table: left_table_id, column: left_col_idx, is_rowid_alias: left_col.is_rowid_alias, }), ast::Operator::Equals, Box::new(Expr::Column { database: None, - table: cur_table_idx, + table: right_table.internal_id, column: right_col_idx, is_rowid_alias: right_col.is_rowid_alias, }), @@ -790,7 +832,11 @@ fn parse_join<'a>( right_table.mark_column_used(right_col_idx); out_where_clause.push(WhereTerm { expr, - from_outer_join: if outer { Some(cur_table_idx) } else { None }, + from_outer_join: if outer { + Some(right_table.internal_id) + } else { + None + }, consumed: false, }); } @@ -858,7 +904,11 @@ pub fn break_predicate_at_and_boundaries(predicate: Expr, out_predicates: &mut V } } -fn parse_row_id(column_name: &str, table_id: usize, fn_check: F) -> Result> +fn parse_row_id( + column_name: &str, + table_id: TableInternalId, + fn_check: F, +) -> Result> where F: FnOnce() -> bool, { diff --git a/core/translate/select.rs b/core/translate/select.rs index 300613cbd..bf2112829 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -10,7 +10,7 @@ use crate::translate::planner::{ parse_where, resolve_aggregates, }; use crate::util::normalize_ident; -use crate::vdbe::builder::{ProgramBuilderOpts, QueryMode}; +use crate::vdbe::builder::{ProgramBuilderOpts, QueryMode, TableRefIdCounter}; use crate::vdbe::insn::Insn; use crate::SymbolTable; use crate::{schema::Schema, vdbe::builder::ProgramBuilder, Result}; @@ -24,7 +24,13 @@ pub fn translate_select( syms: &SymbolTable, mut program: ProgramBuilder, ) -> Result { - let mut select_plan = prepare_select_plan(schema, select, syms, None)?; + let mut select_plan = prepare_select_plan( + schema, + select, + syms, + None, + &mut program.table_reference_counter, + )?; optimize_plan(&mut select_plan, schema)?; let opts = match &select_plan { Plan::Select(select) => ProgramBuilderOpts { @@ -64,6 +70,7 @@ pub fn prepare_select_plan<'a>( mut select: ast::Select, syms: &SymbolTable, outer_scope: Option<&'a Scope<'a>>, + table_ref_counter: &mut TableRefIdCounter, ) -> Result { let compounds = select.body.compounds.take(); match compounds { @@ -77,6 +84,7 @@ pub fn prepare_select_plan<'a>( select.with.take(), syms, outer_scope, + table_ref_counter, )?)) } Some(compounds) => { @@ -88,6 +96,7 @@ pub fn prepare_select_plan<'a>( None, syms, outer_scope, + table_ref_counter, )?; let mut rest = Vec::with_capacity(compounds.len()); for CompoundSelect { select, operator } in compounds { @@ -95,8 +104,16 @@ pub fn prepare_select_plan<'a>( if operator != ast::CompoundOperator::UnionAll { crate::bail_parse_error!("only UNION ALL is supported for compound SELECTs"); } - let plan = - prepare_one_select_plan(schema, *select, None, None, None, syms, outer_scope)?; + let plan = prepare_one_select_plan( + schema, + *select, + None, + None, + None, + syms, + outer_scope, + table_ref_counter, + )?; rest.push((plan, operator)); } // Ensure all subplans have same number of result columns @@ -144,6 +161,7 @@ fn prepare_one_select_plan<'a>( with: Option, syms: &SymbolTable, outer_scope: Option<&'a Scope<'a>>, + table_ref_counter: &mut TableRefIdCounter, ) -> Result { match select { ast::OneSelect::Select(select_inner) => { @@ -163,8 +181,15 @@ fn prepare_one_select_plan<'a>( let mut where_predicates = vec![]; // Parse the FROM clause into a vec of TableReferences. Fold all the join conditions expressions into the WHERE clause. - let table_references = - parse_from(schema, from, syms, with, &mut where_predicates, outer_scope)?; + let table_references = parse_from( + schema, + from, + syms, + with, + &mut where_predicates, + outer_scope, + table_ref_counter, + )?; // Preallocate space for the result columns let result_columns = Vec::with_capacity( @@ -192,7 +217,8 @@ fn prepare_one_select_plan<'a>( .iter() .enumerate() .map(|(i, t)| JoinOrderMember { - table_no: i, + table_id: t.internal_id, + original_idx: i, is_outer: t.join_info.as_ref().map_or(false, |j| j.outer), }) .collect(), @@ -226,13 +252,12 @@ fn prepare_one_select_plan<'a>( let referenced_table = plan .table_references .iter_mut() - .enumerate() - .find(|(_, t)| t.identifier == name_normalized); + .find(|t| t.identifier == name_normalized); if referenced_table.is_none() { crate::bail_parse_error!("Table {} not found", name.0); } - let (table_index, table) = referenced_table.unwrap(); + let table = referenced_table.unwrap(); let num_columns = table.columns().len(); for idx in 0..num_columns { let is_rowid_alias = { @@ -242,7 +267,7 @@ fn prepare_one_select_plan<'a>( plan.result_columns.push(ResultSetColumn { expr: ast::Expr::Column { database: None, // TODO: support different databases - table: table_index, + table: table.internal_id, column: idx, is_rowid_alias, }, diff --git a/core/translate/update.rs b/core/translate/update.rs index 6c7b4c2f5..44be69d3d 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -1,4 +1,5 @@ use crate::translate::plan::Operation; +use crate::vdbe::builder::TableRefIdCounter; use crate::{ bail_parse_error, schema::{Schema, Table}, @@ -54,7 +55,12 @@ pub fn translate_update( parse_schema: ParseSchema, mut program: ProgramBuilder, ) -> crate::Result { - let mut plan = prepare_update_plan(schema, body, parse_schema)?; + let mut plan = prepare_update_plan( + schema, + body, + parse_schema, + &mut program.table_reference_counter, + )?; optimize_plan(&mut plan, schema)?; // TODO: freestyling these numbers let opts = ProgramBuilderOpts { @@ -72,6 +78,7 @@ pub fn prepare_update_plan( schema: &Schema, body: &mut Update, parse_schema: ParseSchema, + table_ref_counter: &mut TableRefIdCounter, ) -> crate::Result { if body.with.is_some() { bail_parse_error!("WITH clause is not supported"); @@ -103,6 +110,7 @@ pub fn prepare_update_plan( _ => unreachable!(), }, identifier: table_name.0.clone(), + internal_id: table_ref_counter.next(), op: Operation::Scan { iter_dir, index: None, diff --git a/core/util.rs b/core/util.rs index 0f71089ec..909145bd0 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1,6 +1,6 @@ use crate::{ schema::{self, Column, Schema, Type}, - translate::{collate::CollationSeq, expr::walk_expr}, + translate::{collate::CollationSeq, expr::walk_expr, plan::JoinOrderMember}, types::{Value, ValueType}, LimboError, OpenFlags, Result, Statement, StepResult, SymbolTable, IO, }; @@ -583,12 +583,20 @@ pub fn columns_from_create_table_body(body: &ast::CreateTableBody) -> crate::Res /// This function checks if a given expression is a constant value that can be pushed down to the database engine. /// It is expected to be called with the other half of a binary expression with an Expr::Column -pub fn can_pushdown_predicate(top_level_expr: &Expr, table_idx: usize) -> Result { +pub fn can_pushdown_predicate( + top_level_expr: &Expr, + table_idx: usize, + join_order: &[JoinOrderMember], +) -> Result { let mut can_pushdown = true; walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> { match expr { Expr::Column { table, .. } | Expr::RowId { table, .. } => { - can_pushdown &= *table <= table_idx; + let join_idx = join_order + .iter() + .position(|t| t.table_id == *table) + .expect("table not found in join_order"); + can_pushdown &= join_idx <= table_idx; } Expr::FunctionCall { args, name, .. } => { let function = crate::function::Func::resolve_function( diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 1f84a5399..a9149f96a 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -5,7 +5,7 @@ use std::{ sync::Arc, }; -use limbo_sqlite3_parser::ast; +use limbo_sqlite3_parser::ast::{self, TableInternalId}; use crate::{ fast_lock::SpinLock, @@ -19,10 +19,28 @@ use crate::{ }, Connection, VirtualTable, }; +pub struct TableRefIdCounter { + next_free: TableInternalId, +} + +impl TableRefIdCounter { + pub fn new() -> Self { + Self { + next_free: TableInternalId::default(), + } + } + + pub fn next(&mut self) -> ast::TableInternalId { + let id = self.next_free; + self.next_free += 1; + id + } +} use super::{BranchOffset, CursorID, Insn, InsnFunction, InsnReference, JumpTarget, Program}; #[allow(dead_code)] pub struct ProgramBuilder { + pub table_reference_counter: TableRefIdCounter, next_free_register: usize, next_free_cursor_id: usize, /// Instruction, the function to execute it with, and its original index in the vector. @@ -90,6 +108,7 @@ pub struct ProgramBuilderOpts { impl ProgramBuilder { pub fn new(opts: ProgramBuilderOpts) -> Self { Self { + table_reference_counter: TableRefIdCounter::new(), next_free_register: 1, next_free_cursor_id: 0, insns: Vec::with_capacity(opts.approx_num_insns), diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index 087df0534..4bf2fbf9c 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -288,6 +288,43 @@ pub struct Delete { pub limit: Option>, } +#[repr(transparent)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Internal ID of a table. +/// +/// Used by [Expr::Column] and [Expr::RowId] to refer to a table. +pub struct TableInternalId(usize); + +impl Default for TableInternalId { + fn default() -> Self { + Self(1) + } +} + +impl From for TableInternalId { + fn from(value: usize) -> Self { + Self(value) + } +} + +impl std::ops::AddAssign for TableInternalId { + fn add_assign(&mut self, rhs: usize) { + self.0 += rhs; + } +} + +impl From for usize { + fn from(value: TableInternalId) -> Self { + value.0 + } +} + +impl std::fmt::Display for TableInternalId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "t{}", self.0) + } +} + /// SQL expression // https://sqlite.org/syntax/expr.html #[derive(Clone, Debug, PartialEq, Eq)] @@ -354,7 +391,7 @@ pub enum Expr { /// the x in `x.y.z`. index of the db in catalog. database: Option, /// the y in `x.y.z`. index of the table in catalog. - table: usize, + table: TableInternalId, /// the z in `x.y.z`. index of the column in the table. column: usize, /// is the column a rowid alias @@ -365,7 +402,7 @@ pub enum Expr { /// the x in `x.y.z`. index of the db in catalog. database: Option, /// the y in `x.y.z`. index of the table in catalog. - table: usize, + table: TableInternalId, }, /// `IN` InList {