From 124b38a26211e7325f50f4e2aad56fe1581a07c9 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Tue, 27 May 2025 19:55:29 +0300 Subject: [PATCH 1/3] plan.rs: add new datastructures - TableReferences struct, which holds both: - joined_tables, and - outer_query_refs - JoinedTable: - this is just a rename of the previous TableReference struct - OuterQueryReference - this is to distinguish from JoinedTable those cases where e.g. a subquery refers to an outer query's table, or a CTE refers to a previous CTE. Both JoinedTable and OuterQueryReference can be referred to by expressions, but only JoinedTables are considered for join ordering optimization and so forth. This commit does not compile. --- core/translate/plan.rs | 207 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 204 insertions(+), 3 deletions(-) diff --git a/core/translate/plan.rs b/core/translate/plan.rs index d47ad740e..d79437de9 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -630,7 +630,7 @@ pub struct JoinInfo { pub using: Option, } -/// A table reference in the query plan. +/// A joined table in the query plan. /// For example, /// ```sql /// SELECT * FROM users u JOIN products p JOIN (SELECT * FROM users) sub; @@ -641,7 +641,7 @@ pub struct JoinInfo { /// - `t` and `p` are [Table::BTree] while `sub` is [Table::FromClauseSubquery] /// - join_info is None for the first table reference, and Some(JoinInfo { outer: false, using: None }) for the second and third table references #[derive(Debug, Clone)] -pub struct TableReference { +pub struct JoinedTable { /// The operation that this table reference performs. pub op: Operation, /// Table object, which contains metadata about the table, e.g. columns. @@ -657,6 +657,207 @@ pub struct TableReference { pub col_used_mask: ColumnUsedMask, } +#[derive(Debug, Clone)] +pub struct OuterQueryReference { + /// The name of the table as referred to in the query, either the literal name or an alias e.g. 
"users" or "u" + pub identifier: String, + /// Internal ID of the table reference, used in e.g. [Expr::Column] to refer to this table. + pub internal_id: TableInternalId, + /// Table object, which contains metadata about the table, e.g. columns. + pub table: Table, + /// Bitmask of columns that are referenced in the query. + /// Used to track dependencies, so that it can be resolved + /// when a WHERE clause subquery should be evaluated; + /// i.e., if the subquery depends on tables T and U, + /// then both T and U need to be in scope for the subquery to be evaluated. + pub col_used_mask: ColumnUsedMask, +} + +impl OuterQueryReference { + /// Returns the columns of the table that this outer query reference refers to. + pub fn columns(&self) -> &[Column] { + self.table.columns() + } + + /// Marks a column as used; used means that the column is referenced in the query. + pub fn mark_column_used(&mut self, column_index: usize) { + self.col_used_mask.set(column_index); + } + + /// Whether the OuterQueryReference is used by the current query scope. + /// This is used primarily to determine at what loop depth a subquery should be evaluated. + pub fn is_used(&self) -> bool { + !self.col_used_mask.is_empty() + } +} + +#[derive(Debug, Clone)] +/// A collection of table references in a given SQL statement. +/// +/// `TableReferences::joined_tables` is the list of tables that are joined together. +/// Example: SELECT * FROM t JOIN u JOIN v -- the joined tables are t, u and v. +/// +/// `TableReferences::outer_query_refs` are references to tables outside the current scope. +/// Example: SELECT * FROM t WHERE EXISTS (SELECT * FROM u WHERE u.foo = t.foo) +/// -- here, 't' is an outer query reference for the subquery (SELECT * FROM u WHERE u.foo = t.foo), +/// since that query does not declare 't' in its FROM clause. 
+/// +/// +/// Typically a query will only have joined tables, but the following may have outer query references: +/// - CTEs that refer to other preceding CTEs +/// - Correlated subqueries, i.e. subqueries that depend on the outer scope +pub struct TableReferences { + /// Tables that are joined together in this query scope. + joined_tables: Vec, + /// Tables from outer scopes that are referenced in this query scope. + outer_query_refs: Vec, +} + +impl TableReferences { + pub fn new( + joined_tables: Vec, + outer_query_refs: Vec, + ) -> Self { + Self { + joined_tables, + outer_query_refs, + } + } + + /// Add a new [JoinedTable] to the query plan. + pub fn add_joined_table(&mut self, joined_table: JoinedTable) { + self.joined_tables.push(joined_table); + } + + /// Returns an immutable reference to the [JoinedTable]s in the query plan. + pub fn joined_tables(&self) -> &[JoinedTable] { + &self.joined_tables + } + + /// Returns a mutable reference to the [JoinedTable]s in the query plan. + pub fn joined_tables_mut(&mut self) -> &mut Vec { + &mut self.joined_tables + } + + /// Returns an immutable reference to the [OuterQueryReference]s in the query plan. + pub fn outer_query_refs(&self) -> &[OuterQueryReference] { + &self.outer_query_refs + } + + /// Returns an immutable reference to the [OuterQueryReference] with the given internal ID. + pub fn find_outer_query_ref_by_internal_id( + &self, + internal_id: TableInternalId, + ) -> Option<&OuterQueryReference> { + self.outer_query_refs + .iter() + .find(|t| t.internal_id == internal_id) + } + + /// Returns a mutable reference to the [OuterQueryReference] with the given internal ID. + pub fn find_outer_query_ref_by_internal_id_mut( + &mut self, + internal_id: TableInternalId, + ) -> Option<&mut OuterQueryReference> { + self.outer_query_refs + .iter_mut() + .find(|t| t.internal_id == internal_id) + } + + /// Returns an immutable reference to the [Table] with the given internal ID. 
+ pub fn find_table_by_internal_id(&self, internal_id: TableInternalId) -> Option<&Table> { + self.joined_tables + .iter() + .find(|t| t.internal_id == internal_id) + .map(|t| &t.table) + .or_else(|| { + self.outer_query_refs + .iter() + .find(|t| t.internal_id == internal_id) + .map(|t| &t.table) + }) + } + + /// Returns an immutable reference to the [Table] with the given identifier, + /// where identifier is either the literal name of the table or an alias. + pub fn find_table_by_identifier(&self, identifier: &str) -> Option<&Table> { + self.joined_tables + .iter() + .find(|t| t.identifier == identifier) + .map(|t| &t.table) + .or_else(|| { + self.outer_query_refs + .iter() + .find(|t| t.identifier == identifier) + .map(|t| &t.table) + }) + } + + /// Returns an immutable reference to the [OuterQueryReference] with the given identifier, + /// where identifier is either the literal name of the table or an alias. + pub fn find_outer_query_ref_by_identifier( + &self, + identifier: &str, + ) -> Option<&OuterQueryReference> { + self.outer_query_refs + .iter() + .find(|t| t.identifier == identifier) + } + + /// Returns the internal ID and immutable reference to the [Table] with the given identifier, + pub fn find_table_and_internal_id_by_identifier( + &self, + identifier: &str, + ) -> Option<(TableInternalId, &Table)> { + self.joined_tables + .iter() + .find(|t| t.identifier == identifier) + .map(|t| (t.internal_id, &t.table)) + .or_else(|| { + self.outer_query_refs + .iter() + .find(|t| t.identifier == identifier) + .map(|t| (t.internal_id, &t.table)) + }) + } + + /// Returns an immutable reference to the [JoinedTable] with the given internal ID. + pub fn find_joined_table_by_internal_id( + &self, + internal_id: TableInternalId, + ) -> Option<&JoinedTable> { + self.joined_tables + .iter() + .find(|t| t.internal_id == internal_id) + } + + /// Returns a mutable reference to the [JoinedTable] with the given internal ID. 
+ pub fn find_joined_table_by_internal_id_mut( + &mut self, + internal_id: TableInternalId, + ) -> Option<&mut JoinedTable> { + self.joined_tables + .iter_mut() + .find(|t| t.internal_id == internal_id) + } + + /// Marks a column as used; used means that the column is referenced in the query. + pub fn mark_column_used(&mut self, internal_id: TableInternalId, column_index: usize) { + if let Some(joined_table) = self.find_joined_table_by_internal_id_mut(internal_id) { + joined_table.mark_column_used(column_index); + } else if let Some(outer_query_ref) = + self.find_outer_query_ref_by_internal_id_mut(internal_id) + { + outer_query_ref.mark_column_used(column_index); + } else { + panic!( + "table with internal id {} not found in table references", + internal_id + ); + } + } +} + #[derive(Clone, Debug, PartialEq, Eq)] #[repr(transparent)] pub struct ColumnUsedMask(u128); @@ -717,7 +918,7 @@ impl Operation { } } -impl TableReference { +impl JoinedTable { /// Returns the btree table for this table reference, if it is a BTreeTable. 
pub fn btree(&self) -> Option> { match &self.table { From cc405dea7e159e64d724f9601c51e774187ac5e9 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Tue, 27 May 2025 19:57:49 +0300 Subject: [PATCH 2/3] Use new TableReferences struct everywhere --- core/lib.rs | 2 +- core/translate/aggregation.rs | 4 +- core/translate/delete.rs | 7 +- core/translate/emitter.rs | 20 +- core/translate/expr.rs | 97 +++--- core/translate/group_by.rs | 7 +- core/translate/main_loop.rs | 240 +++++++------- core/translate/optimizer/access_method.rs | 4 +- core/translate/optimizer/constraints.rs | 6 +- core/translate/optimizer/join.rs | 136 ++++---- core/translate/optimizer/mod.rs | 50 +-- core/translate/optimizer/order.rs | 6 +- core/translate/order_by.rs | 11 +- core/translate/plan.rs | 59 ++-- core/translate/planner.rs | 367 ++++++++++++---------- core/translate/select.rs | 50 +-- core/translate/subquery.rs | 10 +- core/translate/update.rs | 6 +- core/vdbe/builder.rs | 8 +- core/vdbe/mod.rs | 5 +- 20 files changed, 578 insertions(+), 517 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index f1411d37a..c0660644a 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -422,7 +422,7 @@ impl Connection { .deref(), *select, &syms, - None, + &[], &mut table_ref_counter, translate::plan::QueryDestination::ResultRows, )?; diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index de1b54069..ec8ac0274 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -12,7 +12,7 @@ use crate::{ use super::{ emitter::{Resolver, TranslateCtx}, expr::translate_expr, - plan::{Aggregate, Distinctness, SelectPlan, TableReference}, + plan::{Aggregate, Distinctness, SelectPlan, TableReferences}, result_row::emit_select_result, }; @@ -99,7 +99,7 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re /// and the actual result value of the aggregation is materialized. 
pub fn translate_aggregation_step( program: &mut ProgramBuilder, - referenced_tables: &[TableReference], + referenced_tables: &TableReferences, agg: &Aggregate, target_register: usize, resolver: &Resolver, diff --git a/core/translate/delete.rs b/core/translate/delete.rs index 8f3485ce6..3786a0e00 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -7,7 +7,7 @@ use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts, QueryMode, TableR use crate::{schema::Schema, Result, SymbolTable}; use limbo_sqlite3_parser::ast::{Expr, Limit, QualifiedName}; -use super::plan::{ColumnUsedMask, IterationDirection, TableReference}; +use super::plan::{ColumnUsedMask, IterationDirection, JoinedTable, TableReferences}; pub fn translate_delete( query_mode: QueryMode, @@ -64,7 +64,7 @@ pub fn prepare_delete_plan( .iter() .cloned() .collect(); - let mut table_references = vec![TableReference { + let joined_tables = vec![JoinedTable { table, identifier: name, internal_id: table_ref_counter.next(), @@ -75,6 +75,7 @@ pub fn prepare_delete_plan( join_info: None, col_used_mask: ColumnUsedMask::new(), }]; + let mut table_references = TableReferences::new(joined_tables, vec![]); let mut where_predicates = vec![]; @@ -106,5 +107,5 @@ pub fn prepare_delete_plan( fn estimate_num_instructions(plan: &DeletePlan) -> usize { let base = 20; - base + plan.table_references.len() * 10 + base + plan.table_references.joined_tables().len() * 10 } diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index bb6892334..c838991bc 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -16,7 +16,7 @@ use super::main_loop::{ }; use super::order_by::{emit_order_by, init_order_by, SortMetadata}; use super::plan::{ - JoinOrderMember, Operation, QueryDestination, SelectPlan, TableReference, UpdatePlan, + JoinOrderMember, Operation, QueryDestination, SelectPlan, TableReferences, UpdatePlan, }; use super::schema::ParseSchema; use super::select::emit_simple_count; @@ 
-238,7 +238,7 @@ fn emit_program_for_compound_select( program, schema, syms, - first.table_references.len(), + first.table_references.joined_tables().len(), first.result_columns.len(), )); rest.iter().for_each(|(select, _)| { @@ -246,7 +246,7 @@ fn emit_program_for_compound_select( program, schema, syms, - select.table_references.len(), + select.table_references.joined_tables().len(), select.result_columns.len(), ); t_ctx_list.push(t_ctx); @@ -475,7 +475,7 @@ fn emit_program_for_select( program, schema, syms, - plan.table_references.len(), + plan.table_references.joined_tables().len(), plan.result_columns.len(), ); @@ -492,7 +492,7 @@ fn emit_program_for_select( emit_query(program, &mut plan, &mut t_ctx)?; // Finalize program - if plan.table_references.is_empty() { + if plan.table_references.joined_tables().is_empty() { program.epilogue(TransactionMode::None); } else { program.epilogue(TransactionMode::Read); @@ -647,7 +647,7 @@ fn emit_program_for_delete( program, schema, syms, - plan.table_references.len(), + plan.table_references.joined_tables().len(), plan.result_columns.len(), ); @@ -709,10 +709,10 @@ fn emit_program_for_delete( fn emit_delete_insns( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - table_references: &[TableReference], + table_references: &TableReferences, index_references: &[Arc], ) -> Result<()> { - let table_reference = table_references.first().unwrap(); + let table_reference = table_references.joined_tables().first().unwrap(); let cursor_id = match &table_reference.op { Operation::Scan { .. 
} => { program.resolve_cursor_id(&CursorKey::table(table_reference.internal_id)) @@ -811,7 +811,7 @@ fn emit_program_for_update( program, schema, syms, - plan.table_references.len(), + plan.table_references.joined_tables().len(), plan.returning.as_ref().map_or(0, |r| r.len()), ); @@ -895,7 +895,7 @@ fn emit_update_insns( program: &mut ProgramBuilder, index_cursors: Vec<(usize, usize)>, ) -> crate::Result<()> { - let table_ref = &plan.table_references.first().unwrap(); + let table_ref = plan.table_references.joined_tables().first().unwrap(); let loop_labels = t_ctx.labels_main_loop.first().unwrap(); let (cursor_id, index, is_virtual) = match &table_ref.op { Operation::Scan { index, .. } => ( diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 8e47bbcdd..b89e392e7 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -2,7 +2,7 @@ use limbo_sqlite3_parser::ast::{self, UnaryOperator}; use super::emitter::Resolver; use super::optimizer::Optimizable; -use super::plan::TableReference; +use super::plan::TableReferences; #[cfg(feature = "json")] use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc, VectorFunc}; @@ -131,7 +131,7 @@ macro_rules! expect_arguments_even { pub fn translate_condition_expr( program: &mut ProgramBuilder, - referenced_tables: &[TableReference], + referenced_tables: &TableReferences, expr: &ast::Expr, condition_metadata: ConditionMetadata, resolver: &Resolver, @@ -401,7 +401,7 @@ pub enum NoConstantOptReason { /// a register will end up being reused e.g. in a coroutine. pub fn translate_expr_no_constant_opt( program: &mut ProgramBuilder, - referenced_tables: Option<&[TableReference]>, + referenced_tables: Option<&TableReferences>, expr: &ast::Expr, target_register: usize, resolver: &Resolver, @@ -421,7 +421,7 @@ pub fn translate_expr_no_constant_opt( /// Translate an expression into bytecode. 
pub fn translate_expr( program: &mut ProgramBuilder, - referenced_tables: Option<&[TableReference]>, + referenced_tables: Option<&TableReferences>, expr: &ast::Expr, target_register: usize, resolver: &Resolver, @@ -1787,20 +1787,30 @@ pub fn translate_expr( ), ast::Expr::Column { database: _, - table, + table: table_ref_id, column, is_rowid_alias, } => { - let table_reference = referenced_tables - .as_ref() - .unwrap() - .iter() - .find(|t| t.internal_id == *table) - .unwrap(); - let index = table_reference.op.index(); - let use_covering_index = table_reference.utilizes_covering_index(); + let (index, use_covering_index) = { + if let Some(table_reference) = referenced_tables + .unwrap() + .find_joined_table_by_internal_id(*table_ref_id) + { + ( + table_reference.op.index(), + table_reference.utilizes_covering_index(), + ) + } else { + (None, false) + } + }; - let Some(table_column) = table_reference.table.get_column_at(*column) else { + let table = referenced_tables + .unwrap() + .find_table_by_internal_id(*table_ref_id) + .expect("table reference should be found"); + + let Some(table_column) = table.get_column_at(*column) else { crate::bail_parse_error!("column index out of bounds"); }; // Counter intuitive but a column always needs to have a collation @@ -1809,21 +1819,15 @@ pub fn translate_expr( // If we are reading a column from a table, we find the cursor that corresponds to // the table and read the column from the cursor. // If we have a covering index, we don't have an open table cursor so we read from the index cursor. 
- match &table_reference.table { + match &table { Table::BTree(_) => { let table_cursor_id = if use_covering_index { None } else { - Some( - program - .resolve_cursor_id(&CursorKey::table(table_reference.internal_id)), - ) + Some(program.resolve_cursor_id(&CursorKey::table(*table_ref_id))) }; let index_cursor_id = index.map(|index| { - program.resolve_cursor_id(&CursorKey::index( - table_reference.internal_id, - index.clone(), - )) + program.resolve_cursor_id(&CursorKey::index(*table_ref_id, index.clone())) }); if *is_rowid_alias { if let Some(index_cursor_id) = index_cursor_id { @@ -1854,7 +1858,7 @@ pub fn translate_expr( "index cursor should be opened when use_covering_index=true", ); index.column_table_pos_to_index_pos(*column).unwrap_or_else(|| { - panic!("covering index {} does not contain column number {} of table {}", index.name, column, table_reference.identifier) + panic!("covering index {} does not contain column number {} of table {}", index.name, column, table_ref_id) }) } else { *column @@ -1865,7 +1869,7 @@ pub fn translate_expr( dest: target_register, }); } - let Some(column) = table_reference.table.get_column_at(*column) else { + let Some(column) = table.get_column_at(*column) else { crate::bail_parse_error!("column index out of bounds"); }; maybe_apply_affinity(column.ty, target_register, program); @@ -1885,8 +1889,7 @@ pub fn translate_expr( Ok(target_register) } Table::Virtual(_) => { - let cursor_id = - program.resolve_cursor_id(&CursorKey::table(table_reference.internal_id)); + let cursor_id = program.resolve_cursor_id(&CursorKey::table(*table_ref_id)); program.emit_insn(Insn::VColumn { cursor_id, column: *column, @@ -1897,29 +1900,35 @@ pub fn translate_expr( Table::Pseudo(_) => panic!("Column access on pseudo table"), } } - ast::Expr::RowId { database: _, table } => { - let table_reference = referenced_tables - .as_ref() - .unwrap() - .iter() - .find(|t| t.internal_id == *table) - .unwrap(); - let index = table_reference.op.index(); - let 
use_covering_index = table_reference.utilizes_covering_index(); + ast::Expr::RowId { + database: _, + table: table_ref_id, + } => { + let (index, use_covering_index) = { + if let Some(table_reference) = referenced_tables + .unwrap() + .find_joined_table_by_internal_id(*table_ref_id) + { + ( + table_reference.op.index(), + table_reference.utilizes_covering_index(), + ) + } else { + (None, false) + } + }; + if use_covering_index { let index = index.expect("index cursor should be opened when use_covering_index=true"); - let cursor_id = program.resolve_cursor_id(&CursorKey::index( - table_reference.internal_id, - index.clone(), - )); + let cursor_id = + program.resolve_cursor_id(&CursorKey::index(*table_ref_id, index.clone())); program.emit_insn(Insn::IdxRowId { cursor_id, dest: target_register, }); } else { - let cursor_id = - program.resolve_cursor_id(&CursorKey::table(table_reference.internal_id)); + let cursor_id = program.resolve_cursor_id(&CursorKey::table(*table_ref_id)); program.emit_insn(Insn::RowId { cursor_id, dest: target_register, @@ -2437,7 +2446,7 @@ fn emit_binary_insn( /// see [translate_condition_expr] and [translate_expr] for implementations. 
fn translate_like_base( program: &mut ProgramBuilder, - referenced_tables: Option<&[TableReference]>, + referenced_tables: Option<&TableReferences>, expr: &ast::Expr, target_register: usize, resolver: &Resolver, @@ -2496,7 +2505,7 @@ fn translate_like_base( fn translate_function( program: &mut ProgramBuilder, args: &[ast::Expr], - referenced_tables: Option<&[TableReference]>, + referenced_tables: Option<&TableReferences>, resolver: &Resolver, target_register: usize, func_ctx: FuncCtx, diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index afcffcfdc..bde5eaab5 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -20,7 +20,7 @@ use super::{ emitter::{Resolver, TranslateCtx}, expr::{translate_condition_expr, translate_expr, ConditionMetadata}, order_by::order_by_sorter_insert, - plan::{Aggregate, Distinctness, GroupBy, SelectPlan, TableReference}, + plan::{Aggregate, Distinctness, GroupBy, SelectPlan, TableReferences}, result_row::emit_select_result, }; @@ -137,8 +137,7 @@ pub fn init_group_by( ast::Expr::Column { table, column, .. } => { let table_reference = plan .table_references - .iter() - .find(|t| t.internal_id == *table) + .find_joined_table_by_internal_id(*table) .unwrap(); let Some(table_column) = table_reference.table.get_column_at(*column) else { @@ -971,7 +970,7 @@ pub fn group_by_emit_row_phase<'a>( /// and the actual result value of the aggregation is materialized. 
pub fn translate_aggregation_step_groupby( program: &mut ProgramBuilder, - referenced_tables: &[TableReference], + referenced_tables: &TableReferences, agg_arg_source: GroupByAggArgumentSource, target_register: usize, resolver: &Resolver, diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index 1e695730e..1614a41db 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -30,7 +30,7 @@ use super::{ order_by::{order_by_sorter_insert, sorter_insert}, plan::{ convert_where_to_vtab_constraint, Aggregate, GroupBy, IterationDirection, JoinOrderMember, - Operation, QueryDestination, Search, SeekDef, SelectPlan, TableReference, WhereTerm, + Operation, QueryDestination, Search, SeekDef, SelectPlan, TableReferences, WhereTerm, }, }; @@ -110,13 +110,13 @@ pub fn init_distinct(program: &mut ProgramBuilder, plan: &mut SelectPlan) { pub fn init_loop( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - tables: &[TableReference], + tables: &TableReferences, aggregates: &mut [Aggregate], group_by: Option<&GroupBy>, mode: OperationMode, ) -> Result<()> { assert!( - t_ctx.meta_left_joins.len() == tables.len(), + t_ctx.meta_left_joins.len() == tables.joined_tables().len(), "meta_left_joins length does not match tables length" ); // Initialize ephemeral indexes for distinct aggregates @@ -161,7 +161,7 @@ pub fn init_loop( }), }; } - for (table_index, table) in tables.iter().enumerate() { + for (table_index, table) in tables.joined_tables().iter().enumerate() { // Initialize bookkeeping for OUTER JOIN if let Some(join_info) = table.join_info.as_ref() { if join_info.outer { @@ -287,20 +287,20 @@ pub fn init_loop( pub fn open_loop( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - tables: &[TableReference], + table_references: &TableReferences, join_order: &[JoinOrderMember], predicates: &[WhereTerm], ) -> Result<()> { for (join_index, join) in join_order.iter().enumerate() { - let table_index = join.original_idx; - let table = 
&tables[table_index]; + let joined_table_index = join.original_idx; + let table = &table_references.joined_tables()[joined_table_index]; let LoopLabels { loop_start, loop_end, next, } = *t_ctx .labels_main_loop - .get(table_index) + .get(joined_table_index) .expect("table has no loop labels"); // Each OUTER JOIN has a "match flag" that is initially set to false, @@ -308,7 +308,7 @@ pub fn open_loop( // This is used to determine whether to emit actual columns or NULLs for the columns of the right table. if let Some(join_info) = table.join_info.as_ref() { if join_info.outer { - let lj_meta = t_ctx.meta_left_joins[table_index].as_ref().unwrap(); + let lj_meta = t_ctx.meta_left_joins[joined_table_index].as_ref().unwrap(); program.emit_insn(Insn::Integer { value: 0, dest: lj_meta.reg_match_flag, @@ -339,112 +339,116 @@ pub fn open_loop( program.preassign_label_to_next_insn(loop_start); } Table::Virtual(vtab) => { - let (start_reg, count, maybe_idx_str, maybe_idx_int) = if vtab - .kind - .eq(&VTabKind::VirtualTable) - { - // Virtual‑table (non‑TVF) modules can receive constraints via xBestIndex. - // They return information with which to pass to VFilter operation. - // We forward every predicate that touches vtab columns. - // - // vtab.col = literal (always usable) - // vtab.col = outer_table.col (usable, because outer_table is already positioned) - // vtab.col = later_table.col (forwarded with usable = false) - // - // xBestIndex decides which ones it wants by setting argvIndex and whether the - // core layer may omit them (omit = true). - // We then materialise the RHS/LHS into registers before issuing VFilter. 
- let converted_constraints = predicates - .iter() - .filter(|p| p.should_eval_at_loop(join_index, join_order)) - .enumerate() - .filter_map(|(i, p)| { - // Build ConstraintInfo from the predicates - convert_where_to_vtab_constraint(p, table_index, i, join_order) + let (start_reg, count, maybe_idx_str, maybe_idx_int) = + if vtab.kind.eq(&VTabKind::VirtualTable) { + // Virtual‑table (non‑TVF) modules can receive constraints via xBestIndex. + // They return information with which to pass to VFilter operation. + // We forward every predicate that touches vtab columns. + // + // vtab.col = literal (always usable) + // vtab.col = outer_table.col (usable, because outer_table is already positioned) + // vtab.col = later_table.col (forwarded with usable = false) + // + // xBestIndex decides which ones it wants by setting argvIndex and whether the + // core layer may omit them (omit = true). + // We then materialise the RHS/LHS into registers before issuing VFilter. + let converted_constraints = predicates + .iter() + .filter(|p| p.should_eval_at_loop(join_index, join_order)) + .enumerate() + .filter_map(|(i, p)| { + // Build ConstraintInfo from the predicates + convert_where_to_vtab_constraint( + p, + joined_table_index, + i, + join_order, + ) .unwrap_or(None) - }) - .collect::>(); - // TODO: get proper order_by information to pass to the vtab. - // maybe encode more info on t_ctx? we need: [col_idx, is_descending] - let index_info = vtab.best_index(&converted_constraints, &[]); + }) + .collect::>(); + // TODO: get proper order_by information to pass to the vtab. + // maybe encode more info on t_ctx? we need: [col_idx, is_descending] + let index_info = vtab.best_index(&converted_constraints, &[]); - // Determine the number of VFilter arguments (constraints with an argv_index). 
- let args_needed = index_info - .constraint_usages - .iter() - .filter(|u| u.argv_index.is_some()) - .count(); - let start_reg = program.alloc_registers(args_needed); + // Determine the number of VFilter arguments (constraints with an argv_index). + let args_needed = index_info + .constraint_usages + .iter() + .filter(|u| u.argv_index.is_some()) + .count(); + let start_reg = program.alloc_registers(args_needed); - // For each constraint used by best_index, translate the opposite side. - for (i, usage) in index_info.constraint_usages.iter().enumerate() { - if let Some(argv_index) = usage.argv_index { - if let Some(cinfo) = converted_constraints.get(i) { - let (pred_idx, is_rhs) = cinfo.unpack_plan_info(); - if let ast::Expr::Binary(lhs, _, rhs) = - &predicates[pred_idx].expr - { - // translate the opposite side of the referenced vtab column - let expr = if is_rhs { lhs } else { rhs }; - // argv_index is 1-based; adjust to get the proper register offset. - if argv_index == 0 { - // invalid since argv_index is 1-based - continue; - } - let target_reg = start_reg + (argv_index - 1) as usize; - translate_expr( - program, - Some(tables), - expr, - target_reg, - &t_ctx.resolver, - )?; - if cinfo.usable && usage.omit { - predicates[pred_idx].consumed.set(true); + // For each constraint used by best_index, translate the opposite side. + for (i, usage) in index_info.constraint_usages.iter().enumerate() { + if let Some(argv_index) = usage.argv_index { + if let Some(cinfo) = converted_constraints.get(i) { + let (pred_idx, is_rhs) = cinfo.unpack_plan_info(); + if let ast::Expr::Binary(lhs, _, rhs) = + &predicates[pred_idx].expr + { + // translate the opposite side of the referenced vtab column + let expr = if is_rhs { lhs } else { rhs }; + // argv_index is 1-based; adjust to get the proper register offset. 
+ if argv_index == 0 { + // invalid since argv_index is 1-based + continue; + } + let target_reg = + start_reg + (argv_index - 1) as usize; + translate_expr( + program, + Some(table_references), + expr, + target_reg, + &t_ctx.resolver, + )?; + if cinfo.usable && usage.omit { + predicates[pred_idx].consumed.set(true); + } } } } } - } - // If best_index provided an idx_str, translate it. - let maybe_idx_str = if let Some(idx_str) = index_info.idx_str { - let reg = program.alloc_register(); - program.emit_insn(Insn::String8 { - dest: reg, - value: idx_str, - }); - Some(reg) + // If best_index provided an idx_str, translate it. + let maybe_idx_str = if let Some(idx_str) = index_info.idx_str { + let reg = program.alloc_register(); + program.emit_insn(Insn::String8 { + dest: reg, + value: idx_str, + }); + Some(reg) + } else { + None + }; + ( + start_reg, + args_needed, + maybe_idx_str, + Some(index_info.idx_num), + ) } else { - None + // For table-valued functions: translate the table args. + let args = match vtab.args.as_ref() { + Some(args) => args, + None => &vec![], + }; + let start_reg = program.alloc_registers(args.len()); + let mut cur_reg = start_reg; + for arg in args { + let reg = cur_reg; + cur_reg += 1; + let _ = translate_expr( + program, + Some(table_references), + arg, + reg, + &t_ctx.resolver, + )?; + } + (start_reg, args.len(), None, None) }; - ( - start_reg, - args_needed, - maybe_idx_str, - Some(index_info.idx_num), - ) - } else { - // For table-valued functions: translate the table args. - let args = match vtab.args.as_ref() { - Some(args) => args, - None => &vec![], - }; - let start_reg = program.alloc_registers(args.len()); - let mut cur_reg = start_reg; - for arg in args { - let reg = cur_reg; - cur_reg += 1; - let _ = translate_expr( - program, - Some(tables), - arg, - reg, - &t_ctx.resolver, - )?; - } - (start_reg, args.len(), None, None) - }; // Emit VFilter with the computed arguments. 
program.emit_insn(Insn::VFilter { @@ -507,7 +511,7 @@ pub fn open_loop( }; translate_condition_expr( program, - tables, + table_references, &cond.expr, condition_metadata, &t_ctx.resolver, @@ -524,7 +528,13 @@ pub fn open_loop( // Rowid equality point lookups are handled with a SeekRowid instruction which does not loop, since it is a single row lookup. if let Search::RowidEq { cmp_expr } = search { let src_reg = program.alloc_register(); - translate_expr(program, Some(tables), cmp_expr, src_reg, &t_ctx.resolver)?; + translate_expr( + program, + Some(table_references), + cmp_expr, + src_reg, + &t_ctx.resolver, + )?; program.emit_insn(Insn::SeekRowid { cursor_id: table_cursor_id .expect("Search::RowidEq requires a table cursor"), @@ -570,7 +580,7 @@ pub fn open_loop( let start_reg = program.alloc_registers(seek_def.key.len()); emit_seek( program, - tables, + table_references, seek_def, t_ctx, seek_cursor_id, @@ -580,7 +590,7 @@ pub fn open_loop( )?; emit_seek_termination( program, - tables, + table_references, seek_def, t_ctx, seek_cursor_id, @@ -613,7 +623,7 @@ pub fn open_loop( }; translate_condition_expr( program, - tables, + table_references, &cond.expr, condition_metadata, &t_ctx.resolver, @@ -629,7 +639,7 @@ pub fn open_loop( // for the right table's cursor. if let Some(join_info) = table.join_info.as_ref() { if join_info.outer { - let lj_meta = t_ctx.meta_left_joins[table_index].as_ref().unwrap(); + let lj_meta = t_ctx.meta_left_joins[joined_table_index].as_ref().unwrap(); program.resolve_label(lj_meta.label_match_flag_set_true, program.offset()); program.emit_insn(Insn::Integer { value: 1, @@ -902,7 +912,7 @@ fn emit_loop_source<'a>( pub fn close_loop( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - tables: &[TableReference], + tables: &TableReferences, join_order: &[JoinOrderMember], ) -> Result<()> { // We close the loops for all tables in reverse order, i.e. innermost first. 
@@ -915,7 +925,7 @@ pub fn close_loop( // CLOSE t1 for join in join_order.iter().rev() { let table_index = join.original_idx; - let table = &tables[table_index]; + let table = &tables.joined_tables()[table_index]; let loop_labels = *t_ctx .labels_main_loop .get(table_index) @@ -1052,7 +1062,7 @@ pub fn close_loop( #[allow(clippy::too_many_arguments)] fn emit_seek( program: &mut ProgramBuilder, - tables: &[TableReference], + tables: &TableReferences, seek_def: &SeekDef, t_ctx: &mut TranslateCtx, seek_cursor_id: usize, @@ -1163,7 +1173,7 @@ fn emit_seek( #[allow(clippy::too_many_arguments)] fn emit_seek_termination( program: &mut ProgramBuilder, - tables: &[TableReference], + tables: &TableReferences, seek_def: &SeekDef, t_ctx: &mut TranslateCtx, seek_cursor_id: usize, diff --git a/core/translate/optimizer/access_method.rs b/core/translate/optimizer/access_method.rs index e1b926c8e..9fafe1fc6 100644 --- a/core/translate/optimizer/access_method.rs +++ b/core/translate/optimizer/access_method.rs @@ -4,7 +4,7 @@ use limbo_sqlite3_parser::ast::SortOrder; use crate::{ schema::Index, - translate::plan::{IterationDirection, JoinOrderMember, TableReference}, + translate::plan::{IterationDirection, JoinOrderMember, JoinedTable}, Result, }; @@ -48,7 +48,7 @@ impl<'a> AccessMethod<'a> { /// Return the best [AccessMethod] for a given join order. 
pub fn find_best_access_method_for_join_order<'a>( - rhs_table: &TableReference, + rhs_table: &JoinedTable, rhs_constraints: &'a TableConstraints, join_order: &[JoinOrderMember], maybe_order_target: Option<&OrderTarget>, diff --git a/core/translate/optimizer/constraints.rs b/core/translate/optimizer/constraints.rs index 04398971f..fec8fddf3 100644 --- a/core/translate/optimizer/constraints.rs +++ b/core/translate/optimizer/constraints.rs @@ -4,7 +4,7 @@ use crate::{ schema::{Column, Index}, translate::{ expr::as_binary_components, - plan::{JoinOrderMember, TableReference, WhereTerm}, + plan::{JoinOrderMember, TableReferences, WhereTerm}, planner::{table_mask_from_expr, TableMask}, }, Result, @@ -172,13 +172,13 @@ fn estimate_selectivity(column: &Column, op: ast::Operator) -> f64 { /// The resulting list of [TableConstraints] is then used to evaluate the best access methods for various join orders. pub fn constraints_from_where_clause( where_clause: &[WhereTerm], - table_references: &[TableReference], + table_references: &TableReferences, available_indexes: &HashMap>>, ) -> Result> { let mut constraints = Vec::new(); // For each table, collect all the Constraints and all potential index candidates that may use them. - for table_reference in table_references.iter() { + for table_reference in table_references.joined_tables() { let rowid_alias_column = table_reference .columns() .iter() diff --git a/core/translate/optimizer/join.rs b/core/translate/optimizer/join.rs index b74e746ab..63072f73e 100644 --- a/core/translate/optimizer/join.rs +++ b/core/translate/optimizer/join.rs @@ -5,7 +5,7 @@ use limbo_sqlite3_parser::ast::TableInternalId; use crate::{ translate::{ optimizer::{cost::Cost, order::plan_satisfies_order_target}, - plan::{JoinOrderMember, TableReference}, + plan::{JoinOrderMember, JoinedTable}, planner::TableMask, }, Result, @@ -45,7 +45,7 @@ impl JoinN { /// Returns None if the plan is worse than the provided cost upper bound. 
pub fn join_lhs_and_rhs<'a>( lhs: Option<&JoinN>, - rhs_table_reference: &TableReference, + rhs_table_reference: &JoinedTable, rhs_constraints: &'a TableConstraints, join_order: &[JoinOrderMember], maybe_order_target: Option<&OrderTarget>, @@ -118,21 +118,21 @@ pub struct BestJoinOrderResult { /// Compute the best way to join a given set of tables. /// Returns the best [JoinN] if one exists, otherwise returns None. pub fn compute_best_join_order<'a>( - table_references: &[TableReference], + joined_tables: &[JoinedTable], maybe_order_target: Option<&OrderTarget>, constraints: &'a [TableConstraints], access_methods_arena: &'a RefCell>>, ) -> Result> { // Skip work if we have no tables to consider. - if table_references.is_empty() { + if joined_tables.is_empty() { return Ok(None); } - let num_tables = table_references.len(); + let num_tables = joined_tables.len(); // Compute naive left-to-right plan to use as pruning threshold let naive_plan = compute_naive_left_deep_plan( - table_references, + joined_tables, maybe_order_target, access_methods_arena, &constraints, @@ -146,7 +146,7 @@ pub fn compute_best_join_order<'a>( plan_satisfies_order_target( &naive_plan, &access_methods_arena, - table_references, + joined_tables, order_target, ) } else { @@ -154,7 +154,7 @@ pub fn compute_best_join_order<'a>( }; // If we have one table, then the "naive left-to-right plan" is always the best. - if table_references.len() == 1 { + if joined_tables.len() == 1 { return Ok(Some(BestJoinOrderResult { best_plan: naive_plan, best_ordered_plan: None, @@ -188,7 +188,7 @@ pub fn compute_best_join_order<'a>( for i in 0..num_tables { let mut mask = TableMask::new(); mask.add_table(i); - let table_ref = &table_references[i]; + let table_ref = &joined_tables[i]; join_order[0] = JoinOrderMember { table_id: table_ref.internal_id, original_idx: i, @@ -215,7 +215,7 @@ pub fn compute_best_join_order<'a>( // "a LEFT JOIN b" can NOT be reordered as "b LEFT JOIN a". 
// If there are outer joins in the plan, ensure correct ordering. let left_join_illegal_map = { - let left_join_count = table_references + let left_join_count = joined_tables .iter() .filter(|t| t.join_info.as_ref().map_or(false, |j| j.outer)) .count(); @@ -225,9 +225,9 @@ pub fn compute_best_join_order<'a>( // map from rhs table index to lhs table index let mut left_join_illegal_map: HashMap = HashMap::with_capacity(left_join_count); - for (i, _) in table_references.iter().enumerate() { - for j in i + 1..table_references.len() { - if table_references[j] + for (i, _) in joined_tables.iter().enumerate() { + for j in i + 1..joined_tables.len() { + if joined_tables[j] .join_info .as_ref() .map_or(false, |j| j.outer) @@ -295,18 +295,18 @@ pub fn compute_best_join_order<'a>( // Build a JoinOrder out of the table bitmask we are now considering. for table_no in lhs.table_numbers() { join_order.push(JoinOrderMember { - table_id: table_references[table_no].internal_id, + table_id: joined_tables[table_no].internal_id, original_idx: table_no, - is_outer: table_references[table_no] + is_outer: joined_tables[table_no] .join_info .as_ref() .map_or(false, |j| j.outer), }); } join_order.push(JoinOrderMember { - table_id: table_references[rhs_idx].internal_id, + table_id: joined_tables[rhs_idx].internal_id, original_idx: rhs_idx, - is_outer: table_references[rhs_idx] + is_outer: joined_tables[rhs_idx] .join_info .as_ref() .map_or(false, |j| j.outer), @@ -316,7 +316,7 @@ pub fn compute_best_join_order<'a>( // Calculate the best way to join LHS with RHS. let rel = join_lhs_and_rhs( Some(lhs), - &table_references[rhs_idx], + &joined_tables[rhs_idx], &constraints[rhs_idx], &join_order, maybe_order_target, @@ -333,7 +333,7 @@ pub fn compute_best_join_order<'a>( plan_satisfies_order_target( &rel, &access_methods_arena, - table_references, + joined_tables, order_target, ) } else { @@ -396,15 +396,15 @@ pub fn compute_best_join_order<'a>( /// in the SQL query. 
This is used as an upper bound for any other plans -- we can give up enumerating /// permutations if they exceed this cost during enumeration. pub fn compute_naive_left_deep_plan<'a>( - table_references: &[TableReference], + joined_tables: &[JoinedTable], maybe_order_target: Option<&OrderTarget>, access_methods_arena: &'a RefCell>>, constraints: &'a [TableConstraints], ) -> Result { - let n = table_references.len(); + let n = joined_tables.len(); assert!(n > 0); - let join_order = table_references + let join_order = joined_tables .iter() .enumerate() .map(|(i, t)| JoinOrderMember { @@ -417,7 +417,7 @@ pub fn compute_naive_left_deep_plan<'a>( // Start with first table let mut best_plan = join_lhs_and_rhs( None, - &table_references[0], + &joined_tables[0], &constraints[0], &join_order[..1], maybe_order_target, @@ -430,7 +430,7 @@ pub fn compute_naive_left_deep_plan<'a>( for i in 1..n { best_plan = join_lhs_and_rhs( Some(&best_plan), - &table_references[i], + &joined_tables[i], &constraints[i], &join_order[..=i], maybe_order_target, @@ -523,17 +523,17 @@ mod tests { #[test] /// Test that [compute_best_join_order] returns None when there are no table references. 
fn test_compute_best_join_order_empty() { - let table_references = vec![]; + let joined_tables = vec![]; let available_indexes = HashMap::new(); let where_clause = vec![]; let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); let result = compute_best_join_order( - &table_references, + &joined_tables, None, &table_constraints, &access_methods_arena, @@ -547,7 +547,7 @@ mod tests { fn test_compute_best_join_order_single_table_no_indexes() { let t1 = _create_btree_table("test_table", _create_column_list(&["id"], Type::Integer)); let mut table_id_counter = TableRefIdCounter::new(); - let table_references = vec![_create_table_reference( + let joined_tables = vec![_create_table_reference( t1.clone(), None, table_id_counter.next(), @@ -557,13 +557,13 @@ mod tests { let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); // SELECT * from test_table // expecting best_best_plan() not to do any work due to empty where clause. let BestJoinOrderResult { best_plan, .. 
} = compute_best_join_order( - &table_references, + &joined_tables, None, &table_constraints, &access_methods_arena, @@ -581,14 +581,14 @@ mod tests { fn test_compute_best_join_order_single_table_rowid_eq() { let t1 = _create_btree_table("test_table", vec![_create_column_rowid_alias("id")]); let mut table_id_counter = TableRefIdCounter::new(); - let table_references = vec![_create_table_reference( + let joined_tables = vec![_create_table_reference( t1.clone(), None, table_id_counter.next(), )]; let where_clause = vec![_create_binary_expr( - _create_column_expr(table_references[0].internal_id, 0, true), // table 0, column 0 (rowid) + _create_column_expr(joined_tables[0].internal_id, 0, true), // table 0, column 0 (rowid) ast::Operator::Equals, _create_numeric_literal("42"), )]; @@ -596,13 +596,13 @@ mod tests { let access_methods_arena = RefCell::new(Vec::new()); let available_indexes = HashMap::new(); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); // SELECT * FROM test_table WHERE id = 42 // expecting a RowidEq access method because id is a rowid alias. 
let result = compute_best_join_order( - &table_references, + &joined_tables, None, &table_constraints, &access_methods_arena, @@ -630,14 +630,14 @@ mod tests { vec![_create_column_of_type("id", Type::Integer)], ); let mut table_id_counter = TableRefIdCounter::new(); - let table_references = vec![_create_table_reference( + let joined_tables = vec![_create_table_reference( t1.clone(), None, table_id_counter.next(), )]; let where_clause = vec![_create_binary_expr( - _create_column_expr(table_references[0].internal_id, 0, false), // table 0, column 0 (id) + _create_column_expr(joined_tables[0].internal_id, 0, false), // table 0, column 0 (id) ast::Operator::Equals, _create_numeric_literal("42"), )]; @@ -661,12 +661,12 @@ mod tests { available_indexes.insert("test_table".to_string(), vec![index]); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); // SELECT * FROM test_table WHERE id = 42 // expecting an IndexScan access method because id is a primary key with an index let result = compute_best_join_order( - &table_references, + &joined_tables, None, &table_constraints, &access_methods_arena, @@ -694,7 +694,7 @@ mod tests { let t2 = _create_btree_table("table2", _create_column_list(&["id"], Type::Integer)); let mut table_id_counter = TableRefIdCounter::new(); - let mut table_references = vec![ + let mut joined_tables = vec![ _create_table_reference(t1.clone(), None, table_id_counter.next()), _create_table_reference( t2.clone(), @@ -730,18 +730,18 @@ mod tests { // SELECT * FROM table1 JOIN table2 WHERE table1.id = table2.id // expecting table2 to be chosen first due to the index on table1.id let where_clause = vec![_create_binary_expr( - _create_column_expr(table_references[TABLE1].internal_id, 0, false), // table1.id + _create_column_expr(joined_tables[TABLE1].internal_id, 0, false), // table1.id ast::Operator::Equals, 
- _create_column_expr(table_references[TABLE2].internal_id, 0, false), // table2.id + _create_column_expr(joined_tables[TABLE2].internal_id, 0, false), // table2.id )]; let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); let result = compute_best_join_order( - &mut table_references, + &mut joined_tables, None, &table_constraints, &access_methods_arena, @@ -795,7 +795,7 @@ mod tests { ); let mut table_id_counter = TableRefIdCounter::new(); - let table_references = vec![ + let joined_tables = vec![ _create_table_reference(table_orders.clone(), None, table_id_counter.next()), _create_table_reference( table_customers.clone(), @@ -885,19 +885,19 @@ mod tests { let where_clause = vec![ // orders.customer_id = customers.id _create_binary_expr( - _create_column_expr(table_references[TABLE_NO_ORDERS].internal_id, 1, false), // orders.customer_id + _create_column_expr(joined_tables[TABLE_NO_ORDERS].internal_id, 1, false), // orders.customer_id ast::Operator::Equals, - _create_column_expr(table_references[TABLE_NO_CUSTOMERS].internal_id, 0, false), // customers.id + _create_column_expr(joined_tables[TABLE_NO_CUSTOMERS].internal_id, 0, false), // customers.id ), // orders.id = order_items.order_id _create_binary_expr( - _create_column_expr(table_references[TABLE_NO_ORDERS].internal_id, 0, false), // orders.id + _create_column_expr(joined_tables[TABLE_NO_ORDERS].internal_id, 0, false), // orders.id ast::Operator::Equals, - _create_column_expr(table_references[TABLE_NO_ORDER_ITEMS].internal_id, 1, false), // order_items.order_id + _create_column_expr(joined_tables[TABLE_NO_ORDER_ITEMS].internal_id, 1, false), // order_items.order_id ), // customers.id = 42 _create_binary_expr( - _create_column_expr(table_references[TABLE_NO_CUSTOMERS].internal_id, 0, false), // customers.id + 
_create_column_expr(joined_tables[TABLE_NO_CUSTOMERS].internal_id, 0, false), // customers.id ast::Operator::Equals, _create_numeric_literal("42"), ), @@ -905,11 +905,11 @@ mod tests { let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); let result = compute_best_join_order( - &table_references, + &joined_tables, None, &table_constraints, &access_methods_arena, @@ -975,7 +975,7 @@ mod tests { let t3 = _create_btree_table("t3", _create_column_list(&["id", "foo"], Type::Integer)); let mut table_id_counter = TableRefIdCounter::new(); - let mut table_references = vec![ + let mut joined_tables = vec![ _create_table_reference(t1.clone(), None, table_id_counter.next()), _create_table_reference( t2.clone(), @@ -998,13 +998,13 @@ mod tests { let where_clause = vec![ // t2.foo = 42 (equality filter, more selective) _create_binary_expr( - _create_column_expr(table_references[1].internal_id, 1, false), // table 1, column 1 (foo) + _create_column_expr(joined_tables[1].internal_id, 1, false), // table 1, column 1 (foo) ast::Operator::Equals, _create_numeric_literal("42"), ), // t1.foo > 10 (inequality filter, less selective) _create_binary_expr( - _create_column_expr(table_references[0].internal_id, 1, false), // table 0, column 1 (foo) + _create_column_expr(joined_tables[0].internal_id, 1, false), // table 0, column 1 (foo) ast::Operator::Greater, _create_numeric_literal("10"), ), @@ -1013,11 +1013,11 @@ mod tests { let available_indexes = HashMap::new(); let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); let BestJoinOrderResult { best_plan, .. 
} = compute_best_join_order( - &mut table_references, + &mut joined_tables, None, &table_constraints, &access_methods_arena, @@ -1075,7 +1075,7 @@ mod tests { .collect(); let mut table_id_counter = TableRefIdCounter::new(); - let table_references = { + let joined_tables = { let mut refs = vec![_create_table_reference( dim_tables[0].clone(), None, @@ -1106,8 +1106,8 @@ mod tests { // Add join conditions between fact and each dimension table for i in 0..NUM_DIM_TABLES { - let internal_id_fact = table_references[FACT_TABLE_IDX].internal_id; - let internal_id_other = table_references[i].internal_id; + let internal_id_fact = joined_tables[FACT_TABLE_IDX].internal_id; + let internal_id_other = joined_tables[i].internal_id; where_clause.push(_create_binary_expr( _create_column_expr(internal_id_fact, i + 1, false), // fact.dimX_id ast::Operator::Equals, @@ -1118,11 +1118,11 @@ mod tests { let access_methods_arena = RefCell::new(Vec::new()); let available_indexes = HashMap::new(); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); let result = compute_best_join_order( - &table_references, + &joined_tables, None, &table_constraints, &access_methods_arena, @@ -1182,7 +1182,7 @@ mod tests { let mut table_id_counter = TableRefIdCounter::new(); // Create table references - let table_references: Vec<_> = tables + let joined_tables: Vec<_> = tables .iter() .map(|t| _create_table_reference(t.clone(), None, table_id_counter.next())) .collect(); @@ -1190,8 +1190,8 @@ mod tests { // Create where clause linking each table to the next let mut where_clause = Vec::new(); for i in 0..NUM_TABLES - 1 { - let internal_id_left = table_references[i].internal_id; - let internal_id_right = table_references[i + 1].internal_id; + let internal_id_left = joined_tables[i].internal_id; + let internal_id_right = joined_tables[i + 1].internal_id; 
where_clause.push(_create_binary_expr( _create_column_expr(internal_id_left, 1, false), // ti.next_id ast::Operator::Equals, @@ -1201,12 +1201,12 @@ mod tests { let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &table_references, &available_indexes) + constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) .unwrap(); // Run the optimizer let BestJoinOrderResult { best_plan, .. } = compute_best_join_order( - &table_references, + &joined_tables, None, &table_constraints, &access_methods_arena, @@ -1302,9 +1302,9 @@ mod tests { table: Rc, join_info: Option, internal_id: TableInternalId, - ) -> TableReference { + ) -> JoinedTable { let name = table.name.clone(); - TableReference { + JoinedTable { table: Table::BTree(table), op: Operation::Scan { iter_dir: IterationDirection::Forwards, diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index bcb09a0a0..c9126bd58 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -20,8 +20,8 @@ use crate::{ use super::{ emitter::Resolver, plan::{ - DeletePlan, GroupBy, IterationDirection, JoinOrderMember, Operation, Plan, Search, SeekDef, - SeekKey, SelectPlan, TableReference, UpdatePlan, WhereTerm, + DeletePlan, GroupBy, IterationDirection, JoinOrderMember, JoinedTable, Operation, Plan, + Search, SeekDef, SeekKey, SelectPlan, TableReferences, UpdatePlan, WhereTerm, }, }; @@ -52,7 +52,7 @@ pub fn optimize_plan(plan: &mut Plan, schema: &Schema) -> Result<()> { * TODO: these could probably be done in less passes, * but having them separate makes them easier to understand */ -fn optimize_select_plan(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { +pub fn optimize_select_plan(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { optimize_subqueries(plan, schema)?; rewrite_exprs_select(plan)?; if let ConstantConditionEliminationResult::ImpossibleCondition = @@ -116,7 +116,7 
@@ fn optimize_update_plan(plan: &mut UpdatePlan, schema: &Schema) -> Result<()> { } fn optimize_subqueries(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { - for table in plan.table_references.iter_mut() { + for table in plan.table_references.joined_tables_mut() { if let Table::FromClauseSubquery(from_clause_subquery) = &mut table.table { optimize_select_plan(&mut from_clause_subquery.plan, schema)?; } @@ -131,13 +131,13 @@ fn optimize_subqueries(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { /// - Computes a set of [Constraint]s for each table. /// - Using those constraints, computes the best join order for the list of [TableReference]s /// and selects the best [crate::translate::optimizer::access_method::AccessMethod] for each table in the join order. -/// - Mutates the [Operation]s in `table_references` to use the selected access methods. +/// - Mutates the [Operation]s in `joined_tables` to use the selected access methods. /// - Removes predicates from the `where_clause` that are now redundant due to the selected access methods. /// - Removes sorting operations if the selected join order and access methods satisfy the [crate::translate::optimizer::order::OrderTarget]. /// /// Returns the join order if it was optimized, or None if the default join order was considered best. 
fn optimize_table_access( - table_references: &mut [TableReference], + table_references: &mut TableReferences, available_indexes: &HashMap>>, where_clause: &mut Vec, order_by: &mut Option>, @@ -146,9 +146,9 @@ fn optimize_table_access( let access_methods_arena = RefCell::new(Vec::new()); let maybe_order_target = compute_order_target(order_by, group_by.as_mut()); let constraints_per_table = - constraints_from_where_clause(where_clause, table_references, available_indexes)?; + constraints_from_where_clause(where_clause, &table_references, available_indexes)?; let Some(best_join_order_result) = compute_best_join_order( - table_references, + table_references.joined_tables_mut(), maybe_order_target.as_ref(), &constraints_per_table, &access_methods_arena, @@ -162,6 +162,8 @@ fn optimize_table_access( best_ordered_plan, } = best_join_order_result; + let joined_tables = table_references.joined_tables_mut(); + // See if best_ordered_plan is better than the overall best_plan if we add a sorting penalty // to the unordered plan's cost. let best_plan = if let Some(best_ordered_plan) = best_ordered_plan { @@ -184,7 +186,7 @@ fn optimize_table_access( let satisfies_order_target = plan_satisfies_order_target( &best_plan, &access_methods_arena, - table_references, + joined_tables, &order_target, ); if satisfies_order_target { @@ -211,15 +213,15 @@ fn optimize_table_access( let best_join_order: Vec = best_table_numbers .into_iter() .map(|table_number| JoinOrderMember { - table_id: table_references[table_number].internal_id, + table_id: joined_tables[table_number].internal_id, original_idx: table_number, - is_outer: table_references[table_number] + is_outer: joined_tables[table_number] .join_info .as_ref() .map_or(false, |join_info| join_info.outer), }) .collect(); - // Mutate the Operations in `table_references` to use the selected access methods. + // Mutate the Operations in `joined_tables` to use the selected access methods. 
for (i, join_order_member) in best_join_order.iter().enumerate() { let table_idx = join_order_member.original_idx; let access_method = &access_methods_arena.borrow()[best_access_methods[i]]; @@ -227,14 +229,14 @@ fn optimize_table_access( let is_leftmost_table = i == 0; let uses_index = access_method.index.is_some(); let source_table_is_from_clause_subquery = matches!( - &table_references[table_idx].table, + &joined_tables[table_idx].table, Table::FromClauseSubquery(_) ); let try_to_build_ephemeral_index = !is_leftmost_table && !uses_index && !source_table_is_from_clause_subquery; if !try_to_build_ephemeral_index { - table_references[table_idx].op = Operation::Scan { + joined_tables[table_idx].op = Operation::Scan { iter_dir: access_method.iter_dir, index: access_method.index.clone(), }; @@ -246,7 +248,7 @@ fn optimize_table_access( .iter() .find(|c| c.table_id == join_order_member.table_id); let Some(table_constraints) = table_constraints else { - table_references[table_idx].op = Operation::Scan { + joined_tables[table_idx].op = Operation::Scan { iter_dir: access_method.iter_dir, index: access_method.index.clone(), }; @@ -265,19 +267,19 @@ fn optimize_table_access( &best_join_order[..=i], ); if usable_constraint_refs.is_empty() { - table_references[table_idx].op = Operation::Scan { + joined_tables[table_idx].op = Operation::Scan { iter_dir: access_method.iter_dir, index: access_method.index.clone(), }; continue; } let ephemeral_index = ephemeral_index_build( - &table_references[table_idx], + &joined_tables[table_idx], &table_constraints.constraints, &usable_constraint_refs, ); let ephemeral_index = Arc::new(ephemeral_index); - table_references[table_idx].op = Operation::Search(Search::Seek { + joined_tables[table_idx].op = Operation::Search(Search::Seek { index: Some(ephemeral_index), seek_def: build_seek_def_from_constraints( &table_constraints.constraints, @@ -302,7 +304,7 @@ fn optimize_table_access( .set(true); } if let Some(index) = &access_method.index { - 
table_references[table_idx].op = Operation::Search(Search::Seek { + joined_tables[table_idx].op = Operation::Search(Search::Seek { index: Some(index.clone()), seek_def: build_seek_def_from_constraints( &constraints_per_table[table_idx].constraints, @@ -320,7 +322,7 @@ fn optimize_table_access( ); let constraint = &constraints_per_table[table_idx].constraints [constraint_refs[0].constraint_vec_pos]; - table_references[table_idx].op = match constraint.operator { + joined_tables[table_idx].op = match constraint.operator { ast::Operator::Equals => Operation::Search(Search::RowidEq { cmp_expr: constraint.get_constraining_expr(where_clause), }), @@ -458,7 +460,7 @@ pub trait Optimizable { .map_or(false, |c| c == AlwaysTrueOrFalse::AlwaysFalse)) } fn is_constant(&self, resolver: &Resolver<'_>) -> bool; - fn is_nonnull(&self, tables: &[TableReference]) -> bool; + fn is_nonnull(&self, tables: &TableReferences) -> bool; } impl Optimizable for ast::Expr { @@ -468,7 +470,7 @@ impl Optimizable for ast::Expr { /// This function is currently very conservative, and will return false /// for any expression where we aren't sure and didn't bother to find out /// by writing more complex code. - fn is_nonnull(&self, tables: &[TableReference]) -> bool { + fn is_nonnull(&self, tables: &TableReferences) -> bool { match self { Expr::Between { lhs, start, end, .. 
@@ -507,7 +509,7 @@ impl Optimizable for ast::Expr { return true; } - let table_ref = tables.iter().find(|t| t.internal_id == *table).unwrap(); + let table_ref = tables.find_joined_table_by_internal_id(*table).unwrap(); let columns = table_ref.columns(); let column = &columns[*column]; return column.primary_key || column.notnull; @@ -748,7 +750,7 @@ impl Optimizable for ast::Expr { } fn ephemeral_index_build( - table_reference: &TableReference, + table_reference: &JoinedTable, constraints: &[Constraint], constraint_refs: &[ConstraintRef], ) -> Index { diff --git a/core/translate/optimizer/order.rs b/core/translate/optimizer/order.rs index 51b05a5a5..ea1692a9e 100644 --- a/core/translate/optimizer/order.rs +++ b/core/translate/optimizer/order.rs @@ -3,7 +3,7 @@ use std::cell::RefCell; use limbo_sqlite3_parser::ast::{self, SortOrder, TableInternalId}; use crate::{ - translate::plan::{GroupBy, IterationDirection, TableReference}, + translate::plan::{GroupBy, IterationDirection, JoinedTable}, util::exprs_are_equivalent, }; @@ -157,14 +157,14 @@ pub fn compute_order_target( pub fn plan_satisfies_order_target( plan: &JoinN, access_methods_arena: &RefCell>, - table_references: &[TableReference], + joined_tables: &[JoinedTable], order_target: &OrderTarget, ) -> bool { let mut target_col_idx = 0; let num_cols_in_order_target = order_target.0.len(); for (table_index, access_method_index) in plan.data.iter() { let target_col = &order_target.0[target_col_idx]; - let table_ref = &table_references[*table_index]; + let table_ref = &joined_tables[*table_index]; let correct_table = target_col.table_id == table_ref.internal_id; if !correct_table { return false; diff --git a/core/translate/order_by.rs b/core/translate/order_by.rs index 67ab13422..eed98cd3e 100644 --- a/core/translate/order_by.rs +++ b/core/translate/order_by.rs @@ -16,7 +16,7 @@ use crate::{ use super::{ emitter::{Resolver, TranslateCtx}, expr::translate_expr, - plan::{Distinctness, ResultSetColumn, SelectPlan, 
TableReference}, + plan::{Distinctness, ResultSetColumn, SelectPlan, TableReferences}, result_row::{emit_offset, emit_result_row_and_limit}, }; @@ -34,7 +34,7 @@ pub fn init_order_by( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, order_by: &[(ast::Expr, SortOrder)], - referenced_tables: &[TableReference], + referenced_tables: &TableReferences, ) -> Result<()> { let sort_cursor = program.alloc_cursor_id(CursorType::Sorter); t_ctx.meta_sort = Some(SortMetadata { @@ -54,12 +54,9 @@ pub fn init_order_by( .map(|(expr, _)| match expr { ast::Expr::Collate(_, collation_name) => CollationSeq::new(collation_name).map(Some), ast::Expr::Column { table, column, .. } => { - let table_reference = referenced_tables - .iter() - .find(|t| t.internal_id == *table) - .unwrap(); + let table = referenced_tables.find_table_by_internal_id(*table).unwrap(); - let Some(table_column) = table_reference.table.get_column_at(*column) else { + let Some(table_column) = table.get_column_at(*column) else { crate::bail_parse_error!("column index out of bounds"); }; diff --git a/core/translate/plan.rs b/core/translate/plan.rs index d79437de9..5d2e4f81d 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -35,24 +35,19 @@ pub struct ResultSetColumn { } impl ResultSetColumn { - pub fn name<'a>(&'a self, tables: &'a [TableReference]) -> Option<&'a str> { + pub fn name<'a>(&'a self, tables: &'a TableReferences) -> Option<&'a str> { if let Some(alias) = &self.alias { return Some(alias); } match &self.expr { ast::Expr::Column { table, column, .. } => { - let table_ref = tables.iter().find(|t| t.internal_id == *table).unwrap(); - table_ref - .table - .get_column_at(*column) - .unwrap() - .name - .as_deref() + let table_ref = tables.find_table_by_internal_id(*table).unwrap(); + table_ref.get_column_at(*column).unwrap().name.as_deref() } ast::Expr::RowId { table, .. 
} => { // If there is a rowid alias column, use its name - let table_ref = tables.iter().find(|t| t.internal_id == *table).unwrap(); - if let Table::BTree(table) = &table_ref.table { + let table_ref = tables.find_table_by_internal_id(*table).unwrap(); + if let Table::BTree(table) = &table_ref { if let Some(rowid_alias_column) = table.get_rowid_alias_column() { if let Some(name) = &rowid_alias_column.1.name { return Some(name); @@ -345,7 +340,7 @@ pub struct JoinOrderMember { /// The internal ID of the[TableReference] pub table_id: TableInternalId, /// The index of the table in the original join order. - /// This is used to index into e.g. [SelectPlan::table_references] + /// This is used to index into e.g. [TableReferences::joined_tables()] pub original_idx: usize, /// Whether this member is the right side of an OUTER JOIN pub is_outer: bool, @@ -429,9 +424,8 @@ impl DistinctCtx { #[derive(Debug, Clone)] pub struct SelectPlan { - /// List of table references in loop order, outermost first. - pub table_references: Vec, - /// The order in which the tables are joined. Tables have usize Ids (their index in table_references) + pub table_references: TableReferences, + /// The order in which the tables are joined. Tables have usize Ids (their index in joined_tables) pub join_order: Vec, /// the columns inside SELECT ... FROM pub result_columns: Vec, @@ -459,6 +453,10 @@ pub struct SelectPlan { } impl SelectPlan { + pub fn joined_tables(&self) -> &[JoinedTable] { + self.table_references.joined_tables() + } + pub fn agg_args_count(&self) -> usize { self.aggregates.iter().map(|agg| agg.args.len()).sum() } @@ -502,7 +500,8 @@ impl SelectPlan { self.query_destination, QueryDestination::CoroutineYield { .. 
} ) - || self.table_references.len() != 1 + || self.table_references.joined_tables().len() != 1 + || self.table_references.outer_query_refs().len() != 0 || self.result_columns.len() != 1 || self.group_by.is_some() || self.contains_constant_false_condition @@ -510,7 +509,7 @@ impl SelectPlan { { return false; } - let table_ref = self.table_references.first().unwrap(); + let table_ref = self.table_references.joined_tables().first().unwrap(); if !matches!(table_ref.table, crate::schema::Table::BTree(..)) { return false; } @@ -541,8 +540,7 @@ impl SelectPlan { #[allow(dead_code)] #[derive(Debug, Clone)] pub struct DeletePlan { - /// List of table references. Delete is always a single table. - pub table_references: Vec, + pub table_references: TableReferences, /// the columns inside SELECT ... FROM pub result_columns: Vec, /// where clause split into a vec at 'AND' boundaries. @@ -561,8 +559,7 @@ pub struct DeletePlan { #[derive(Debug, Clone)] pub struct UpdatePlan { - // table being updated is always first - pub table_references: Vec, + pub table_references: TableReferences, // (colum index, new value) pairs pub set_clauses: Vec<(usize, ast::Expr)>, pub where_clause: Vec, @@ -583,7 +580,7 @@ pub enum IterationDirection { Backwards, } -pub fn select_star(tables: &[TableReference], out_columns: &mut Vec) { +pub fn select_star(tables: &[JoinedTable], out_columns: &mut Vec) { for table in tables.iter() { let maybe_using_cols = table .join_info @@ -856,6 +853,16 @@ impl TableReferences { ); } } + + pub fn contains_table(&self, table: &Table) -> bool { + self.joined_tables.iter().any(|t| t.table == *table) + || self.outer_query_refs.iter().any(|t| t.table == *table) + } + + pub fn extend(&mut self, other: TableReferences) { + self.joined_tables.extend(other.joined_tables); + self.outer_query_refs.extend(other.outer_query_refs); + } } #[derive(Clone, Debug, PartialEq, Eq)] @@ -1233,8 +1240,8 @@ impl Display for SelectPlan { writeln!(f, "QUERY PLAN")?; // Print each table 
reference with appropriate indentation based on join depth - for (i, reference) in self.table_references.iter().enumerate() { - let is_last = i == self.table_references.len() - 1; + for (i, reference) in self.table_references.joined_tables().iter().enumerate() { + let is_last = i == self.table_references.joined_tables().len() - 1; let indent = if i == 0 { if is_last { "`--" } else { "|--" }.to_string() } else { @@ -1284,7 +1291,7 @@ impl Display for DeletePlan { writeln!(f, "QUERY PLAN")?; // Delete plan should only have one table reference - if let Some(reference) = self.table_references.first() { + if let Some(reference) = self.table_references.joined_tables().first() { let indent = "`--"; match &reference.op { @@ -1310,8 +1317,8 @@ impl fmt::Display for UpdatePlan { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "QUERY PLAN")?; - for (i, reference) in self.table_references.iter().enumerate() { - let is_last = i == self.table_references.len() - 1; + for (i, reference) in self.table_references.joined_tables().iter().enumerate() { + let is_last = i == self.table_references.joined_tables().len() - 1; let indent = if i == 0 { if is_last { "`--" } else { "|--" }.to_string() } else { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 28c7c1b0e..0863b1c68 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -4,8 +4,8 @@ use super::{ expr::walk_expr, plan::{ Aggregate, ColumnUsedMask, Distinctness, EvalAt, IterationDirection, JoinInfo, - JoinOrderMember, Operation, Plan, QueryDestination, ResultSetColumn, SelectPlan, - TableReference, WhereTerm, + JoinOrderMember, JoinedTable, Operation, OuterQueryReference, Plan, QueryDestination, + ResultSetColumn, TableReferences, WhereTerm, }, select::prepare_select_plan, SymbolTable, @@ -97,7 +97,7 @@ pub fn resolve_aggregates(top_level_expr: &Expr, aggs: &mut Vec) -> R pub fn bind_column_references( top_level_expr: &mut Expr, - referenced_tables: &mut 
[TableReference], + referenced_tables: &mut TableReferences, result_columns: Option<&[ResultSetColumn]>, ) -> Result<()> { walk_expr_mut(top_level_expr, &mut |expr: &mut Expr| -> Result<()> { @@ -110,20 +110,22 @@ pub fn bind_column_references( } let normalized_id = normalize_ident(id.0.as_str()); - if !referenced_tables.is_empty() { - if let Some(row_id_expr) = - parse_row_id(&normalized_id, referenced_tables[0].internal_id, || { - referenced_tables.len() != 1 - })? - { + if !referenced_tables.joined_tables().is_empty() { + if let Some(row_id_expr) = parse_row_id( + &normalized_id, + referenced_tables.joined_tables()[0].internal_id, + || referenced_tables.joined_tables().len() != 1, + )? { *expr = row_id_expr; return Ok(()); } } let mut match_result = None; - for (tbl_idx, table) in referenced_tables.iter().enumerate() { - let col_idx = table.columns().iter().position(|c| { + + // First check joined tables + for joined_table in referenced_tables.joined_tables().iter() { + let col_idx = joined_table.table.columns().iter().position(|c| { c.name .as_ref() .map_or(false, |name| name.eq_ignore_ascii_case(&normalized_id)) @@ -132,18 +134,49 @@ pub fn bind_column_references( if match_result.is_some() { crate::bail_parse_error!("Column {} is ambiguous", id.0); } - let col = table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some((tbl_idx, col_idx.unwrap(), col.is_rowid_alias)); + let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + joined_table.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); } } - if let Some((tbl_idx, col_idx, is_rowid_alias)) = match_result { + + // Then check outer query references, if we still didn't find something. + // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) + // but in the case of subqueries, the inner query takes precedence. 
+ // For example: + // SELECT * FROM t WHERE x = (SELECT x FROM t2) + // In this case, there is no ambiguity: + // - x in the outer query refers to t.x, + // - x in the inner query refers to t2.x. + if match_result.is_none() { + for outer_ref in referenced_tables.outer_query_refs().iter() { + let col_idx = outer_ref.table.columns().iter().position(|c| { + c.name + .as_ref() + .map_or(false, |name| name.eq_ignore_ascii_case(&normalized_id)) + }); + if col_idx.is_some() { + if match_result.is_some() { + crate::bail_parse_error!("Column {} is ambiguous", id.0); + } + let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = + Some((outer_ref.internal_id, col_idx.unwrap(), col.is_rowid_alias)); + } + } + } + + if let Some((table_id, col_idx, is_rowid_alias)) = match_result { *expr = Expr::Column { database: None, // TODO: support different databases - table: referenced_tables[tbl_idx].internal_id, + table: table_id, column: col_idx, is_rowid_alias, }; - referenced_tables[tbl_idx].mark_column_used(col_idx); + referenced_tables.mark_column_used(table_id, col_idx); return Ok(()); } @@ -162,43 +195,35 @@ pub fn bind_column_references( } Expr::Qualified(tbl, id) => { let normalized_table_name = normalize_ident(tbl.0.as_str()); - let matching_tbl_idx = referenced_tables - .iter() - .position(|t| t.identifier.eq_ignore_ascii_case(&normalized_table_name)); - if matching_tbl_idx.is_none() { + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_table_name); + if matching_tbl.is_none() { crate::bail_parse_error!("Table {} not found", normalized_table_name); } - let tbl_idx = matching_tbl_idx.unwrap(); + let (tbl_id, tbl) = matching_tbl.unwrap(); let normalized_id = normalize_ident(id.0.as_str()); - if let Some(row_id_expr) = parse_row_id( - &normalized_id, - referenced_tables[tbl_idx].internal_id, - || false, - )? { + if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? 
{ *expr = row_id_expr; return Ok(()); } - let col_idx = referenced_tables[tbl_idx].columns().iter().position(|c| { + let col_idx = tbl.columns().iter().position(|c| { c.name .as_ref() .map_or(false, |name| name.eq_ignore_ascii_case(&normalized_id)) }); - if col_idx.is_none() { + let Some(col_idx) = col_idx else { crate::bail_parse_error!("Column {} not found", normalized_id); - } - let col = referenced_tables[tbl_idx] - .columns() - .get(col_idx.unwrap()) - .unwrap(); + }; + let col = tbl.columns().get(col_idx).unwrap(); *expr = Expr::Column { database: None, // TODO: support different databases - table: referenced_tables[tbl_idx].internal_id, - column: col_idx.unwrap(), + table: tbl_id, + column: col_idx, is_rowid_alias: col.is_rowid_alias, }; - referenced_tables[tbl_idx].mark_column_used(col_idx.unwrap()); + referenced_tables.mark_column_used(tbl_id, col_idx); Ok(()) } _ => Ok(()), @@ -209,7 +234,8 @@ pub fn bind_column_references( fn parse_from_clause_table<'a>( schema: &Schema, table: ast::SelectTable, - scope: &mut Scope<'a>, + table_references: &mut TableReferences, + ctes: &mut Vec, syms: &SymbolTable, table_ref_counter: &mut TableRefIdCounter, ) -> Result<()> { @@ -217,22 +243,16 @@ fn parse_from_clause_table<'a>( ast::SelectTable::Table(qualified_name, maybe_alias, _) => { let normalized_qualified_name = normalize_ident(qualified_name.name.0.as_str()); // Check if the FROM clause table is referring to a CTE in the current scope. - if let Some(cte) = scope - .ctes + if let Some(cte_idx) = ctes .iter() - .find(|cte| cte.name == normalized_qualified_name) + .position(|cte| cte.identifier == normalized_qualified_name) { - // CTE can be rewritten as a subquery. - // TODO: find a way not to clone the CTE plan here. - let cte_table = TableReference::new_subquery( - cte.name.clone(), - cte.plan.clone(), - None, - table_ref_counter.next(), - ); - scope.tables.push(cte_table); + // TODO: what if the CTE is referenced multiple times? 
+ let cte_table = ctes.remove(cte_idx); + table_references.add_joined_table(cte_table); return Ok(()); }; + // Check if our top level schema has this table. if let Some(table) = schema.get_table(&normalized_qualified_name) { let alias = maybe_alias @@ -250,7 +270,7 @@ fn parse_from_clause_table<'a>( "Table type not supported".to_string(), )); }; - scope.tables.push(TableReference { + table_references.add_joined_table(JoinedTable { op: Operation::Scan { iter_dir: IterationDirection::Forwards, index: None, @@ -264,30 +284,28 @@ fn parse_from_clause_table<'a>( return Ok(()); }; - // Check if the outer query scope has this table. - if let Some(outer_scope) = scope.parent { - if let Some(table_ref_idx) = outer_scope - .tables - .iter() - .position(|t| t.identifier == normalized_qualified_name) - { - // TODO: avoid cloning the table reference here. - scope.tables.push(outer_scope.tables[table_ref_idx].clone()); - return Ok(()); - } - if let Some(cte) = outer_scope - .ctes - .iter() - .find(|cte| cte.name == normalized_qualified_name) - { - // TODO: avoid cloning the CTE plan here. - let cte_table = TableReference::new_subquery( - cte.name.clone(), - cte.plan.clone(), - None, - table_ref_counter.next(), - ); - scope.tables.push(cte_table); + // CTEs are transformed into FROM clause subqueries. + // If we find a CTE with this name in our outer query references, + // we can use it as a joined table, but we must clone it since it's not MATERIALIZED. + // + // For other types of tables in the outer query references, we do not add them as joined tables, + // because the query can simply _reference_ them in e.g. the SELECT columns or the WHERE clause, + // but it's not part of the join order. 
+ if let Some(outer_ref) = + table_references.find_outer_query_ref_by_identifier(&normalized_qualified_name) + { + if matches!(outer_ref.table, Table::FromClauseSubquery(_)) { + table_references.add_joined_table(JoinedTable { + op: Operation::Scan { + iter_dir: IterationDirection::Forwards, + index: None, + }, + table: outer_ref.table.clone(), + identifier: outer_ref.identifier.clone(), + internal_id: table_ref_counter.next(), + join_info: None, + col_used_mask: ColumnUsedMask::new(), + }); return Ok(()); } } @@ -295,29 +313,28 @@ fn parse_from_clause_table<'a>( crate::bail_parse_error!("Table {} not found", normalized_qualified_name); } ast::SelectTable::Select(subselect, maybe_alias) => { - let query_destination = QueryDestination::CoroutineYield { - yield_reg: usize::MAX, // will be set later in bytecode emission - coroutine_implementation_start: BranchOffset::Placeholder, // will be set later in bytecode emission - }; let Plan::Select(subplan) = prepare_select_plan( schema, *subselect, syms, - Some(scope), + table_references.outer_query_refs(), table_ref_counter, - query_destination, + QueryDestination::CoroutineYield { + yield_reg: usize::MAX, // will be set later in bytecode emission + coroutine_implementation_start: BranchOffset::Placeholder, // will be set later in bytecode emission + }, )? 
else { crate::bail_parse_error!("Only non-compound SELECT queries are currently supported in FROM clause subqueries"); }; - let cur_table_index = scope.tables.len(); + let cur_table_index = table_references.joined_tables().len(); let identifier = maybe_alias .map(|a| match a { ast::As::As(id) => id.0.clone(), ast::As::Elided(id) => id.0.clone(), }) .unwrap_or(format!("subquery_{}", cur_table_index)); - scope.tables.push(TableReference::new_subquery( + table_references.add_joined_table(JoinedTable::new_subquery( identifier, subplan, None, @@ -347,7 +364,7 @@ fn parse_from_clause_table<'a>( }) .unwrap_or(normalized_name.to_string()); - scope.tables.push(TableReference { + table_references.add_joined_table(JoinedTable { op: Operation::Scan { iter_dir: IterationDirection::Forwards, index: None, @@ -365,64 +382,20 @@ fn parse_from_clause_table<'a>( } } -/// A scope is a list of tables that are visible to the current query. -/// It is used to resolve table references in the FROM clause. -/// To resolve table references that are potentially ambiguous, the resolution -/// first looks at schema tables and tables in the current scope (which currently just means CTEs in the current query), -/// and only after that looks at whether a table from an outer (upper) query level matches. -/// -/// For example: -/// -/// WITH nested AS (SELECT foo FROM bar) -/// WITH sub AS (SELECT foo FROM bar) -/// SELECT * FROM sub -/// -/// 'sub' would preferentially refer to the 'foo' column from the 'bar' table in the catalog. -/// With an explicit reference like: -/// -/// SELECT nested.foo FROM sub -/// -/// 'nested.foo' would refer to the 'foo' column from the 'nested' CTE. -/// -/// TODO: we should probably use Scope in all of our identifier resolution, because it allows for e.g. -/// WITH users AS (SELECT * FROM products) SELECT * FROM users <-- returns products, even if there is a table named 'users' in the catalog! 
-/// -/// Currently we are treating Schema as a first-class object in identifier resolution, when in reality -/// be part of the 'Scope' struct. -pub struct Scope<'a> { - /// The tables that are explicitly present in the current query, including catalog tables and CTEs. - tables: Vec, - ctes: Vec, - /// The parent scope, if any. For example, a second CTE has access to the first CTE via the parent scope. - parent: Option<&'a Scope<'a>>, -} - -pub struct Cte { - /// The name of the CTE. - name: String, - /// The query plan for the CTE. - /// Currently we only support SELECT queries in CTEs. - plan: SelectPlan, -} - pub fn parse_from<'a>( schema: &Schema, mut from: Option, syms: &SymbolTable, with: Option, out_where_clause: &mut Vec, - outer_scope: Option<&'a Scope<'a>>, + table_references: &mut TableReferences, table_ref_counter: &mut TableRefIdCounter, -) -> Result> { +) -> Result<()> { if from.as_ref().and_then(|f| f.select.as_ref()).is_none() { - return Ok(vec![]); + return Ok(()); } - let mut scope = Scope { - tables: vec![], - ctes: vec![], - parent: outer_scope, - }; + let mut ctes_as_subqueries = vec![]; if let Some(with) = with { if with.recursive { @@ -446,63 +419,82 @@ pub fn parse_from<'a>( cte.tbl_name.0 ); } - if scope - .tables + if table_references + .outer_query_refs() .iter() .any(|t| t.identifier == cte_name_normalized) { - crate::bail_parse_error!("CTE name {} conflicts with table name", cte.tbl_name.0); - } - if scope.ctes.iter().any(|c| c.name == cte_name_normalized) { - crate::bail_parse_error!("duplicate WITH table name {}", cte.tbl_name.0); + crate::bail_parse_error!( + "CTE name {} conflicts with WITH table name {}", + cte.tbl_name.0, + cte_name_normalized + ); } - // CTE can be rewritten as a subquery. 
- let query_destination = QueryDestination::CoroutineYield { - yield_reg: usize::MAX, // will be set later in bytecode emission - coroutine_implementation_start: BranchOffset::Placeholder, // will be set later in bytecode emission - }; + let mut outer_query_refs_for_cte = table_references.outer_query_refs().to_vec(); + outer_query_refs_for_cte.extend(ctes_as_subqueries.iter().map(|t: &JoinedTable| { + OuterQueryReference { + identifier: t.identifier.clone(), + internal_id: t.internal_id, + table: t.table.clone(), + col_used_mask: ColumnUsedMask::new(), + } + })); + // CTE can refer to other CTEs that came before it, plus any schema tables or tables in the outer scope. let cte_plan = prepare_select_plan( schema, *cte.select, syms, - Some(&scope), + &outer_query_refs_for_cte, table_ref_counter, - query_destination, + QueryDestination::CoroutineYield { + yield_reg: usize::MAX, // will be set later in bytecode emission + coroutine_implementation_start: BranchOffset::Placeholder, // will be set later in bytecode emission + }, )?; let Plan::Select(cte_plan) = cte_plan else { crate::bail_parse_error!("Only SELECT queries are currently supported in CTEs"); }; - scope.ctes.push(Cte { - name: cte_name_normalized, - plan: cte_plan, - }); + ctes_as_subqueries.push(JoinedTable::new_subquery( + cte_name_normalized, + cte_plan, + None, + table_ref_counter.next(), + )); } } let mut from_owned = std::mem::take(&mut from).unwrap(); let select_owned = *std::mem::take(&mut from_owned.select).unwrap(); let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default(); - parse_from_clause_table(schema, select_owned, &mut scope, syms, table_ref_counter)?; + parse_from_clause_table( + schema, + select_owned, + table_references, + &mut ctes_as_subqueries, + syms, + table_ref_counter, + )?; for join in joins_owned.into_iter() { parse_join( schema, join, syms, - &mut scope, + &mut ctes_as_subqueries, out_where_clause, + table_references, table_ref_counter, )?; } - Ok(scope.tables) 
+ Ok(()) } pub fn parse_where( where_clause: Option, - table_references: &mut [TableReference], + table_references: &mut TableReferences, result_columns: Option<&[ResultSetColumn]>, out_where_clause: &mut Vec, ) -> Result<()> { @@ -549,7 +541,7 @@ pub fn determine_where_to_eval_term( } /// A bitmask representing a set of tables in a query plan. -/// Tables are numbered by their index in [SelectPlan::table_references]. +/// Tables are numbered by their index in [SelectPlan::joined_tables]. /// In the bitmask, the first bit is unused so that a mask with all zeros /// can represent "no tables". /// @@ -647,17 +639,27 @@ impl TableMask { /// Used in the optimizer for constraint analysis. pub fn table_mask_from_expr( top_level_expr: &Expr, - table_references: &[TableReference], + table_references: &TableReferences, ) -> Result { let mut mask = TableMask::new(); walk_expr(top_level_expr, &mut |expr: &Expr| -> Result<()> { match expr { Expr::Column { table, .. } | Expr::RowId { table, .. } => { - let table_idx = table_references + if let Some(table_idx) = table_references + .joined_tables() .iter() .position(|t| t.internal_id == *table) - .expect("table not found in table_references"); - mask.add_table(table_idx); + { + mask.add_table(table_idx); + } else if table_references + .find_outer_query_ref_by_internal_id(*table) + .is_none() + { + // Tables from outer query scopes are guaranteed to be 'in scope' for this query, + // so they don't need to be added to the table mask. However, if the table is not found + // in the outer scope either, then it's an invalid reference. 
+ crate::bail_parse_error!("table not found in joined_tables"); + } } _ => {} } @@ -693,8 +695,9 @@ fn parse_join<'a>( schema: &Schema, join: ast::JoinedSelectTable, syms: &SymbolTable, - scope: &mut Scope<'a>, + ctes: &mut Vec, out_where_clause: &mut Vec, + table_references: &mut TableReferences, table_ref_counter: &mut TableRefIdCounter, ) -> Result<()> { let ast::JoinedSelectTable { @@ -703,7 +706,14 @@ fn parse_join<'a>( constraint, } = join; - parse_from_clause_table(schema, table, scope, syms, table_ref_counter)?; + parse_from_clause_table( + schema, + table, + table_references, + ctes, + syms, + table_ref_counter, + )?; let (outer, natural) = match join_operator { ast::JoinOperator::TypedJoin(Some(join_type)) => { @@ -721,15 +731,19 @@ fn parse_join<'a>( } let constraint = if natural { - assert!(scope.tables.len() >= 2); - let rightmost_table = scope.tables.last().unwrap(); + assert!(table_references.joined_tables().len() >= 2); + let rightmost_table = table_references.joined_tables().last().unwrap(); // NATURAL JOIN is first transformed into a USING join with the common columns let right_cols = rightmost_table.columns(); let mut distinct_names: Option = None; // TODO: O(n^2) maybe not great for large tables or big multiway joins for right_col in right_cols.iter() { let mut found_match = false; - for left_table in scope.tables.iter().take(scope.tables.len() - 1) { + for left_table in table_references + .joined_tables() + .iter() + .take(table_references.joined_tables().len() - 1) + { for left_col in left_table.columns().iter() { if left_col.name == right_col.name { if let Some(distinct_names) = distinct_names.as_mut() { @@ -767,13 +781,13 @@ fn parse_join<'a>( let mut preds = vec![]; break_predicate_at_and_boundaries(expr, &mut preds); for predicate in preds.iter_mut() { - bind_column_references(predicate, &mut scope.tables, None)?; + bind_column_references(predicate, table_references, None)?; } for pred in preds { out_where_clause.push(WhereTerm { expr: 
pred, from_outer_join: if outer { - Some(scope.tables.last().unwrap().internal_id) + Some(table_references.joined_tables().last().unwrap().internal_id) } else { None }, @@ -785,10 +799,10 @@ fn parse_join<'a>( // USING join is replaced with a list of equality predicates for distinct_name in distinct_names.iter() { let name_normalized = normalize_ident(distinct_name.0.as_str()); - let cur_table_idx = scope.tables.len() - 1; - let left_tables = &scope.tables[..cur_table_idx]; + let cur_table_idx = table_references.joined_tables().len() - 1; + let left_tables = &table_references.joined_tables()[..cur_table_idx]; assert!(!left_tables.is_empty()); - let right_table = scope.tables.last().unwrap(); + let right_table = table_references.joined_tables().last().unwrap(); let mut left_col = None; for (left_table_idx, left_table) in left_tables.iter().enumerate() { left_col = left_table @@ -840,9 +854,15 @@ fn parse_join<'a>( }), ); - let left_table = scope.tables.get_mut(left_table_idx).unwrap(); + let left_table: &mut JoinedTable = table_references + .joined_tables_mut() + .get_mut(left_table_idx) + .unwrap(); left_table.mark_column_used(left_col_idx); - let right_table = scope.tables.get_mut(cur_table_idx).unwrap(); + let right_table: &mut JoinedTable = table_references + .joined_tables_mut() + .get_mut(cur_table_idx) + .unwrap(); right_table.mark_column_used(right_col_idx); out_where_clause.push(WhereTerm { expr, @@ -859,9 +879,12 @@ fn parse_join<'a>( } } - assert!(scope.tables.len() >= 2); - let last_idx = scope.tables.len() - 1; - let rightmost_table = scope.tables.get_mut(last_idx).unwrap(); + assert!(table_references.joined_tables().len() >= 2); + let last_idx = table_references.joined_tables().len() - 1; + let rightmost_table = table_references + .joined_tables_mut() + .get_mut(last_idx) + .unwrap(); rightmost_table.join_info = Some(JoinInfo { outer, using }); Ok(()) diff --git a/core/translate/select.rs b/core/translate/select.rs index 5e02d9e25..392bfcf53 100644 --- 
a/core/translate/select.rs +++ b/core/translate/select.rs @@ -1,8 +1,8 @@ use super::emitter::{emit_program, TranslateCtx}; use super::plan::{ - select_star, Distinctness, JoinOrderMember, Operation, QueryDestination, Search, + select_star, Distinctness, JoinOrderMember, Operation, OuterQueryReference, QueryDestination, + Search, TableReferences, }; -use super::planner::Scope; use crate::function::{AggFunc, ExtFunc, Func}; use crate::schema::Table; use crate::translate::optimizer::optimize_plan; @@ -36,7 +36,7 @@ pub fn translate_select( schema, select, syms, - None, + &[], &mut program.table_reference_counter, query_destination, )?; @@ -90,7 +90,7 @@ pub fn prepare_select_plan<'a>( schema: &Schema, mut select: ast::Select, syms: &SymbolTable, - outer_scope: Option<&'a Scope<'a>>, + outer_query_refs: &[OuterQueryReference], table_ref_counter: &mut TableRefIdCounter, query_destination: QueryDestination, ) -> Result { @@ -105,7 +105,7 @@ pub fn prepare_select_plan<'a>( select.order_by.take(), select.with.take(), syms, - outer_scope, + outer_query_refs, table_ref_counter, query_destination, )?)) @@ -118,7 +118,7 @@ pub fn prepare_select_plan<'a>( None, None, syms, - outer_scope, + outer_query_refs, table_ref_counter, query_destination.clone(), )?; @@ -139,7 +139,7 @@ pub fn prepare_select_plan<'a>( None, None, syms, - outer_scope, + outer_query_refs, table_ref_counter, query_destination.clone(), )?; @@ -189,7 +189,7 @@ fn prepare_one_select_plan<'a>( order_by: Option>, with: Option, syms: &SymbolTable, - outer_scope: Option<&'a Scope<'a>>, + outer_query_refs: &[OuterQueryReference], table_ref_counter: &mut TableRefIdCounter, query_destination: QueryDestination, ) -> Result { @@ -210,14 +210,16 @@ fn prepare_one_select_plan<'a>( let mut where_predicates = vec![]; + let mut table_references = TableReferences::new(vec![], outer_query_refs.to_vec()); + // Parse the FROM clause into a vec of TableReferences. Fold all the join conditions expressions into the WHERE clause. 
- let table_references = parse_from( + parse_from( schema, from, syms, with, &mut where_predicates, - outer_scope, + &mut table_references, table_ref_counter, )?; @@ -227,11 +229,14 @@ fn prepare_one_select_plan<'a>( .iter() .map(|c| match c { // Allocate space for all columns in all tables - ResultColumn::Star => { - table_references.iter().map(|t| t.columns().len()).sum() - } + ResultColumn::Star => table_references + .joined_tables() + .iter() + .map(|t| t.columns().len()) + .sum(), // Guess 5 columns if we can't find the table using the identifier (maybe it's in [brackets] or `tick_quotes`, or miXeDcAse) ResultColumn::TableStar(n) => table_references + .joined_tables() .iter() .find(|t| t.identifier == n.0) .map(|t| t.columns().len()) @@ -244,6 +249,7 @@ fn prepare_one_select_plan<'a>( let mut plan = SelectPlan { join_order: table_references + .joined_tables() .iter() .enumerate() .map(|(i, t)| JoinOrderMember { @@ -270,8 +276,11 @@ fn prepare_one_select_plan<'a>( for column in columns.iter_mut() { match column { ResultColumn::Star => { - select_star(&plan.table_references, &mut plan.result_columns); - for table in plan.table_references.iter_mut() { + select_star( + &plan.table_references.joined_tables(), + &mut plan.result_columns, + ); + for table in plan.table_references.joined_tables_mut() { for idx in 0..table.columns().len() { table.mark_column_used(idx); } @@ -281,6 +290,7 @@ fn prepare_one_select_plan<'a>( let name_normalized = normalize_ident(name.0.as_str()); let referenced_table = plan .table_references + .joined_tables_mut() .iter_mut() .find(|t| t.identifier == name_normalized); @@ -566,7 +576,7 @@ fn prepare_one_select_plan<'a>( } let plan = SelectPlan { join_order: vec![], - table_references: vec![], + table_references: TableReferences::new(vec![], vec![]), result_columns, where_clause: vec![], group_by: None, @@ -612,7 +622,7 @@ fn replace_column_number_with_copy_of_column_expr( fn count_plan_required_cursors(plan: &SelectPlan) -> usize { let 
num_table_cursors: usize = plan - .table_references + .joined_tables() .iter() .map(|t| match &t.op { Operation::Scan { .. } => 1, @@ -634,7 +644,7 @@ fn count_plan_required_cursors(plan: &SelectPlan) -> usize { fn estimate_num_instructions(select: &SelectPlan) -> usize { let table_instructions: usize = select - .table_references + .joined_tables() .iter() .map(|t| match &t.op { Operation::Scan { .. } => 10, @@ -663,7 +673,7 @@ fn estimate_num_labels(select: &SelectPlan) -> usize { let init_halt_labels = 2; // 3 loop labels for each table in main loop + 1 to signify end of main loop let table_labels = select - .table_references + .joined_tables() .iter() .map(|t| match &t.op { Operation::Scan { .. } => 3, @@ -692,7 +702,7 @@ pub fn emit_simple_count<'a>( plan: &'a SelectPlan, ) -> Result<()> { let cursors = plan - .table_references + .joined_tables() .get(0) .unwrap() .resolve_cursors(program)?; diff --git a/core/translate/subquery.rs b/core/translate/subquery.rs index 5f95db9e9..26c9df8bc 100644 --- a/core/translate/subquery.rs +++ b/core/translate/subquery.rs @@ -7,7 +7,7 @@ use crate::{ use super::{ emitter::{emit_query, Resolver, TranslateCtx}, main_loop::LoopLabels, - plan::{QueryDestination, SelectPlan, TableReference}, + plan::{QueryDestination, SelectPlan, TableReferences}, }; /// Emit the subqueries contained in the FROM clause. @@ -15,9 +15,9 @@ use super::{ pub fn emit_subqueries( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, - tables: &mut [TableReference], + tables: &mut TableReferences, ) -> Result<()> { - for table_reference in tables.iter_mut() { + for table_reference in tables.joined_tables_mut() { if let Table::FromClauseSubquery(from_clause_subquery) = &mut table_reference.table { // Emit the subquery and get the start register of the result columns. 
let result_columns_start = @@ -65,12 +65,12 @@ pub fn emit_subquery<'a>( } let end_coroutine_label = program.allocate_label(); let mut metadata = TranslateCtx { - labels_main_loop: (0..plan.table_references.len()) + labels_main_loop: (0..plan.joined_tables().len()) .map(|_| LoopLabels::new(program)) .collect(), label_main_loop_end: None, meta_group_by: None, - meta_left_joins: (0..plan.table_references.len()).map(|_| None).collect(), + meta_left_joins: (0..plan.joined_tables().len()).map(|_| None).collect(), meta_sort: None, reg_agg_start: None, reg_nonagg_emit_once_flag: None, diff --git a/core/translate/update.rs b/core/translate/update.rs index b6f6d3486..9d3d72f9e 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -12,7 +12,8 @@ use limbo_sqlite3_parser::ast::{self, Expr, ResultColumn, SortOrder, Update}; use super::emitter::emit_program; use super::optimizer::optimize_plan; use super::plan::{ - ColumnUsedMask, IterationDirection, Plan, ResultSetColumn, TableReference, UpdatePlan, + ColumnUsedMask, IterationDirection, JoinedTable, Plan, ResultSetColumn, TableReferences, + UpdatePlan, }; use super::planner::bind_column_references; use super::planner::{parse_limit, parse_where}; @@ -103,7 +104,7 @@ pub fn prepare_update_plan( }) }) .unwrap_or(IterationDirection::Forwards); - let mut table_references = vec![TableReference { + let joined_tables = vec![JoinedTable { table: match table.as_ref() { Table::Virtual(vtab) => Table::Virtual(vtab.clone()), Table::BTree(btree_table) => Table::BTree(btree_table.clone()), @@ -118,6 +119,7 @@ pub fn prepare_update_plan( join_info: None, col_used_mask: ColumnUsedMask::new(), }]; + let mut table_references = TableReferences::new(joined_tables, vec![]); let set_clauses = body .sets .iter_mut() diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index aa3f9a315..1c78e5bf6 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -15,7 +15,7 @@ use crate::{ translate::{ collate::CollationSeq, 
emitter::TransactionMode, - plan::{ResultSetColumn, TableReference}, + plan::{ResultSetColumn, TableReferences}, }, Connection, VirtualTable, }; @@ -106,7 +106,7 @@ pub struct ProgramBuilder { comments: Option>, pub parameters: Parameters, pub result_columns: Vec, - pub table_references: Vec, + pub table_references: TableReferences, /// Curr collation sequence. Bool indicates whether it was set by a COLLATE expr collation: Option<(CollationSeq, bool)>, /// Current parsing nesting level @@ -170,7 +170,7 @@ impl ProgramBuilder { }, parameters: Parameters::new(), result_columns: Vec::new(), - table_references: Vec::new(), + table_references: TableReferences::new(vec![], vec![]), collation: None, nested_level: 0, // These labels will be filled when `prologue()` is called @@ -740,7 +740,7 @@ impl ProgramBuilder { /// Checks whether `table` or any of its indices has been opened in the program pub fn is_table_open(&self, table: &Table) -> bool { - self.table_references.iter().any(|t| t.table == *table) + self.table_references.contains_table(table) } pub fn build( diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index d62413c58..520c6af10 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -29,11 +29,12 @@ use crate::{ fast_lock::SpinLock, function::{AggFunc, FuncCtx}, storage::{pager::PagerCacheflushStatus, sqlite3_ondisk::SmallVec}, + translate::plan::TableReferences, }; use crate::{ storage::{btree::BTreeCursor, pager::Pager, sqlite3_ondisk::DatabaseHeader}, - translate::plan::{ResultSetColumn, TableReference}, + translate::plan::ResultSetColumn, types::{AggContext, Cursor, CursorResult, ImmutableRecord, SeekKey, SeekOp, Value}, vdbe::{builder::CursorType, insn::Insn}, }; @@ -393,7 +394,7 @@ pub struct Program { pub n_change: Cell, pub change_cnt_on: bool, pub result_columns: Vec, - pub table_references: Vec, + pub table_references: TableReferences, } impl Program { From 211b511189bc9d14a2acdd3a2301f579fc49816e Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Tue, 
27 May 2025 20:12:25 +0300 Subject: [PATCH 3/3] Fix join optimizer tests --- core/translate/optimizer/join.rs | 54 +++++++++++++++++++------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/core/translate/optimizer/join.rs b/core/translate/optimizer/join.rs index 63072f73e..97737f50a 100644 --- a/core/translate/optimizer/join.rs +++ b/core/translate/optimizer/join.rs @@ -503,7 +503,9 @@ mod tests { schema::{BTreeTable, Column, Index, IndexColumn, Table, Type}, translate::{ optimizer::constraints::{constraints_from_where_clause, BinaryExprSide}, - plan::{ColumnUsedMask, IterationDirection, JoinInfo, Operation, WhereTerm}, + plan::{ + ColumnUsedMask, IterationDirection, JoinInfo, Operation, TableReferences, WhereTerm, + }, planner::TableMask, }, vdbe::builder::TableRefIdCounter, @@ -523,17 +525,17 @@ mod tests { #[test] /// Test that [compute_best_join_order] returns None when there are no table references. fn test_compute_best_join_order_empty() { - let joined_tables = vec![]; + let table_references = TableReferences::new(vec![], vec![]); let available_indexes = HashMap::new(); let where_clause = vec![]; let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) .unwrap(); let result = compute_best_join_order( - &joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena, @@ -552,18 +554,19 @@ mod tests { None, table_id_counter.next(), )]; + let table_references = TableReferences::new(joined_tables, vec![]); let available_indexes = HashMap::new(); let where_clause = vec![]; let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) 
.unwrap(); // SELECT * from test_table // expecting best_best_plan() not to do any work due to empty where clause. let BestJoinOrderResult { best_plan, .. } = compute_best_join_order( - &joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena, @@ -593,16 +596,17 @@ mod tests { _create_numeric_literal("42"), )]; + let table_references = TableReferences::new(joined_tables, vec![]); let access_methods_arena = RefCell::new(Vec::new()); let available_indexes = HashMap::new(); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) .unwrap(); // SELECT * FROM test_table WHERE id = 42 // expecting a RowidEq access method because id is a rowid alias. let result = compute_best_join_order( - &joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena, @@ -642,6 +646,7 @@ mod tests { _create_numeric_literal("42"), )]; + let table_references = TableReferences::new(joined_tables, vec![]); let access_methods_arena = RefCell::new(Vec::new()); let mut available_indexes = HashMap::new(); let index = Arc::new(Index { @@ -661,12 +666,12 @@ mod tests { available_indexes.insert("test_table".to_string(), vec![index]); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) .unwrap(); // SELECT * FROM test_table WHERE id = 42 // expecting an IndexScan access method because id is a primary key with an index let result = compute_best_join_order( - &joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena, @@ -694,7 +699,7 @@ mod tests { let t2 = _create_btree_table("table2", _create_column_list(&["id"], Type::Integer)); let mut table_id_counter = TableRefIdCounter::new(); - let mut joined_tables = 
vec![ + let joined_tables = vec![ _create_table_reference(t1.clone(), None, table_id_counter.next()), _create_table_reference( t2.clone(), @@ -735,13 +740,14 @@ mod tests { _create_column_expr(joined_tables[TABLE2].internal_id, 0, false), // table2.id )]; + let table_references = TableReferences::new(joined_tables, vec![]); let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) .unwrap(); let result = compute_best_join_order( - &mut joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena, @@ -903,13 +909,14 @@ mod tests { ), ]; + let table_references = TableReferences::new(joined_tables, vec![]); let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) .unwrap(); let result = compute_best_join_order( - &joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena, @@ -975,7 +982,7 @@ mod tests { let t3 = _create_btree_table("t3", _create_column_list(&["id", "foo"], Type::Integer)); let mut table_id_counter = TableRefIdCounter::new(); - let mut joined_tables = vec![ + let joined_tables = vec![ _create_table_reference(t1.clone(), None, table_id_counter.next()), _create_table_reference( t2.clone(), @@ -1010,14 +1017,15 @@ mod tests { ), ]; + let table_references = TableReferences::new(joined_tables, vec![]); let available_indexes = HashMap::new(); let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) .unwrap(); let 
BestJoinOrderResult { best_plan, .. } = compute_best_join_order( - &mut joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena, @@ -1115,14 +1123,15 @@ mod tests { )); } + let table_references = TableReferences::new(joined_tables, vec![]); let access_methods_arena = RefCell::new(Vec::new()); let available_indexes = HashMap::new(); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) .unwrap(); let result = compute_best_join_order( - &joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena, @@ -1199,14 +1208,15 @@ mod tests { )); } + let table_references = TableReferences::new(joined_tables, vec![]); let access_methods_arena = RefCell::new(Vec::new()); let table_constraints = - constraints_from_where_clause(&where_clause, &joined_tables, &available_indexes) + constraints_from_where_clause(&where_clause, &table_references, &available_indexes) .unwrap(); // Run the optimizer let BestJoinOrderResult { best_plan, .. } = compute_best_join_order( - &joined_tables, + table_references.joined_tables(), None, &table_constraints, &access_methods_arena,