More structured query planner

This commit is contained in:
jussisaurio
2024-08-11 16:58:32 +03:00
parent c2944f6eeb
commit 2e32ca0bdb
13 changed files with 3314 additions and 2392 deletions

View File

@@ -17,6 +17,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use fallible_iterator::FallibleIterator;
use log::trace;
use schema::Schema;
use sqlite3_parser::ast;
use sqlite3_parser::{ast::Cmd, lexer::sql::Parser};
use std::sync::Arc;
use std::{cell::RefCell, rc::Rc};
@@ -27,6 +28,9 @@ use storage::sqlite3_ondisk::DatabaseHeader;
#[cfg(feature = "fs")]
use storage::wal::WalFile;
use translate::optimizer::optimize_plan;
use translate::planner::prepare_select_plan;
pub use error::LimboError;
pub type Result<T> = std::result::Result<T, error::LimboError>;
@@ -173,7 +177,17 @@ impl Connection {
program.explain();
Ok(None)
}
Cmd::ExplainQueryPlan(_stmt) => Ok(None),
Cmd::ExplainQueryPlan(stmt) => {
match stmt {
ast::Stmt::Select(select) => {
let plan = prepare_select_plan(&self.schema, select)?;
let plan = optimize_plan(plan)?;
println!("{}", plan);
}
_ => todo!(),
}
Ok(None)
}
}
} else {
Ok(None)
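For illustration, the plan printed here is the Display output of the optimized Operator tree defined in core/translate/plan.rs below. Assuming the usual test schema where users.id and products.id are INTEGER PRIMARY KEY (rowid aliases), a query such as

EXPLAIN QUERY PLAN SELECT u.first_name FROM users u JOIN products p ON u.id = p.id LIMIT 3;

would print roughly the following (exact expression formatting comes from sqlite3_parser's Display impls):

TAKE 3
  PROJECT u.first_name
    JOIN
      SCAN users AS u
      SEEK products.rowid ON rowid=u.id

i.e. the ON predicate has been pushed down by the optimizer and rewritten into a rowid seek on products.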

838
core/translate/emitter.rs Normal file
View File

@@ -0,0 +1,838 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::usize;
use crate::schema::{BTreeTable, Column, PseudoTable, Table};
use crate::storage::sqlite3_ondisk::DatabaseHeader;
use crate::types::{OwnedRecord, OwnedValue};
use crate::vdbe::builder::ProgramBuilder;
use crate::vdbe::{BranchOffset, Insn, Program};
use crate::Result;
use super::expr::maybe_apply_affinity;
use super::expr::{
translate_aggregation, translate_condition_expr, translate_expr, ConditionMetadata,
};
use super::plan::Plan;
use super::plan::{Operator, ProjectionColumn};
/**
* The Emitter trait is used to emit bytecode instructions for a given operator in the query plan.
*
* - start: open cursors, etc.
* - emit: open loops, emit conditional jumps etc.
* - end: close loops, etc.
* - result_columns: emit the bytecode instructions for the result columns.
* - result_row: emit the bytecode instructions for a result row.
*/
pub trait Emitter {
fn start(
&mut self,
pb: &mut ProgramBuilder,
m: &mut Metadata,
referenced_tables: &[(Rc<BTreeTable>, String)],
) -> Result<()>;
fn emit(
&mut self,
pb: &mut ProgramBuilder,
m: &mut Metadata,
referenced_tables: &[(Rc<BTreeTable>, String)],
can_emit_row: bool,
) -> Result<bool>;
fn end(
&mut self,
pb: &mut ProgramBuilder,
m: &mut Metadata,
referenced_tables: &[(Rc<BTreeTable>, String)],
) -> Result<()>;
fn result_columns(
&self,
program: &mut ProgramBuilder,
referenced_tables: &[(Rc<BTreeTable>, String)],
metadata: &mut Metadata,
cursor_override: Option<usize>,
) -> Result<usize>;
fn result_row(
&mut self,
program: &mut ProgramBuilder,
referenced_tables: &[(Rc<BTreeTable>, String)],
metadata: &mut Metadata,
cursor_override: Option<usize>,
) -> Result<bool>;
}
#[derive(Debug)]
pub struct LeftJoinMetadata {
// integer register that holds a flag that is set to true if the current row has a match for the left join
pub match_flag_register: usize,
// label for the instruction that sets the match flag to true
pub set_match_flag_true_label: BranchOffset,
// label for the instruction that checks if the match flag is true
pub check_match_flag_label: BranchOffset,
// label for the instruction where the program jumps to if the current row has a match for the left join
pub on_match_jump_to_label: BranchOffset,
}
#[derive(Debug)]
pub struct SortMetadata {
pub sort_cursor: usize,
pub sort_register: usize,
pub next_row_label: BranchOffset,
pub done_label: BranchOffset,
}
#[derive(Debug)]
pub struct Metadata {
termination_labels: Vec<BranchOffset>,
next_row_labels: HashMap<usize, BranchOffset>,
rewind_labels: Vec<BranchOffset>,
aggregations: HashMap<usize, usize>,
sorts: HashMap<usize, SortMetadata>,
left_joins: HashMap<usize, LeftJoinMetadata>,
}
impl Emitter for Operator {
fn start(
&mut self,
program: &mut ProgramBuilder,
m: &mut Metadata,
referenced_tables: &[(Rc<BTreeTable>, String)],
) -> Result<()> {
match self {
Operator::Scan {
table,
table_identifier,
id,
..
} => {
let cursor_id = program.alloc_cursor_id(
Some(table_identifier.clone()),
Some(Table::BTree(table.clone())),
);
let root_page = table.root_page;
let next_row_label = program.allocate_label();
m.next_row_labels.insert(*id, next_row_label);
program.emit_insn(Insn::OpenReadAsync {
cursor_id,
root_page,
});
program.emit_insn(Insn::OpenReadAwait);
Ok(())
}
Operator::SeekRowid {
table,
table_identifier,
..
} => {
let cursor_id = program.alloc_cursor_id(
Some(table_identifier.clone()),
Some(Table::BTree(table.clone())),
);
let root_page = table.root_page;
program.emit_insn(Insn::OpenReadAsync {
cursor_id,
root_page,
});
program.emit_insn(Insn::OpenReadAwait);
Ok(())
}
Operator::Join {
left,
right,
outer,
id,
..
} => {
if *outer {
let lj_metadata = LeftJoinMetadata {
match_flag_register: program.alloc_register(),
set_match_flag_true_label: program.allocate_label(),
check_match_flag_label: program.allocate_label(),
on_match_jump_to_label: program.allocate_label(),
};
m.left_joins.insert(*id, lj_metadata);
}
left.start(program, m, referenced_tables)?;
right.start(program, m, referenced_tables)
}
Operator::Aggregate {
id,
source,
aggregates,
} => {
let can_continue = source.start(program, m, referenced_tables)?;
let agg_final_label = program.allocate_label();
m.termination_labels.push(agg_final_label);
source.emit(program, m, referenced_tables, false)?;
let num_aggs = aggregates.len();
let start_reg = program.alloc_registers(num_aggs);
m.aggregations.insert(*id, start_reg);
Ok(can_continue)
}
Operator::Filter { .. } => unreachable!("predicates have been pushed down"),
Operator::Limit { source, .. } => source.start(program, m, referenced_tables),
Operator::Order { id, source, key } => {
let sort_cursor = program.alloc_cursor_id(None, None);
m.sorts.insert(
*id,
SortMetadata {
sort_cursor,
sort_register: usize::MAX, // will be set later
next_row_label: program.allocate_label(),
done_label: program.allocate_label(),
},
);
let mut order = Vec::new();
for (_, direction) in key.iter() {
order.push(OwnedValue::Integer(*direction as i64));
}
program.emit_insn(Insn::SorterOpen {
cursor_id: sort_cursor,
columns: key.len(),
order: OwnedRecord::new(order),
});
source.start(program, m, referenced_tables)
}
Operator::Projection { source, .. } => source.start(program, m, referenced_tables),
Operator::Nothing => Ok(()),
}
}
fn emit(
&mut self,
program: &mut ProgramBuilder,
m: &mut Metadata,
referenced_tables: &[(Rc<BTreeTable>, String)],
can_emit_row: bool,
) -> Result<bool> {
match self {
Operator::Aggregate {
source,
aggregates,
id,
} => {
let can_continue = source.emit(program, m, referenced_tables, false)?;
if !can_continue {
return Ok(false);
}
let start_reg = m.aggregations.get(id).unwrap();
for (i, agg) in aggregates.iter().enumerate() {
let agg_result_reg = start_reg + i;
translate_aggregation(program, referenced_tables, agg, agg_result_reg, None)?;
}
Ok(false)
}
Operator::Filter { .. } => unreachable!("predicates have been pushed down"),
Operator::SeekRowid {
rowid_predicate,
predicates,
table_identifier,
id,
..
} => {
let cursor_id = program.resolve_cursor_id(table_identifier, None);
let rowid_reg = program.alloc_register();
translate_expr(
program,
Some(referenced_tables),
rowid_predicate,
rowid_reg,
None,
)?;
let jump_label = m
.next_row_labels
.get(id)
.unwrap_or(&m.termination_labels.last().unwrap());
program.emit_insn_with_label_dependency(
Insn::SeekRowid {
cursor_id,
src_reg: rowid_reg,
target_pc: *jump_label,
},
*jump_label,
);
if let Some(predicates) = predicates {
for predicate in predicates.iter() {
let jump_target_when_true = program.allocate_label();
let condition_metadata = ConditionMetadata {
jump_if_condition_is_true: false,
jump_target_when_true,
jump_target_when_false: *jump_label,
};
translate_condition_expr(
program,
referenced_tables,
predicate,
None,
condition_metadata,
)?;
program.resolve_label(jump_target_when_true, program.offset());
}
}
Ok(true)
}
Operator::Limit { source, .. } => {
source.emit(program, m, referenced_tables, false)?;
Ok(true)
}
Operator::Join {
left,
right,
predicates,
outer,
id,
} => {
left.emit(program, m, referenced_tables, false)?;
let mut jump_target_when_false = *m
.next_row_labels
.get(&right.id())
.unwrap_or(&m.termination_labels.last().unwrap());
if *outer {
let lj_meta = m.left_joins.get(id).unwrap();
program.emit_insn(Insn::Integer {
value: 0,
dest: lj_meta.match_flag_register,
});
jump_target_when_false = lj_meta.check_match_flag_label;
m.next_row_labels.insert(right.id(), jump_target_when_false);
}
right.emit(program, m, referenced_tables, false)?;
if let Some(predicates) = predicates {
let jump_target_when_true = program.allocate_label();
let condition_metadata = ConditionMetadata {
jump_if_condition_is_true: false,
jump_target_when_true,
jump_target_when_false,
};
for predicate in predicates.iter() {
translate_condition_expr(
program,
referenced_tables,
predicate,
None,
condition_metadata,
)?;
}
program.resolve_label(jump_target_when_true, program.offset());
}
if *outer {
let lj_meta = m.left_joins.get(id).unwrap();
program.defer_label_resolution(
lj_meta.set_match_flag_true_label,
program.offset() as usize,
);
program.emit_insn(Insn::Integer {
value: 1,
dest: lj_meta.match_flag_register,
});
}
if can_emit_row {
return self.result_row(program, referenced_tables, m, None);
}
Ok(true)
}
Operator::Order { source, key, id } => {
source.emit(program, m, referenced_tables, false)?;
let sort_keys_count = key.len();
let source_cols_count = source.column_count(referenced_tables);
let start_reg = program.alloc_registers(sort_keys_count);
for (i, (expr, _)) in key.iter().enumerate() {
let key_reg = start_reg + i;
translate_expr(program, Some(referenced_tables), expr, key_reg, None)?;
}
source.result_columns(program, referenced_tables, m, None)?;
let dest = program.alloc_register();
program.emit_insn(Insn::MakeRecord {
start_reg,
count: sort_keys_count + source_cols_count,
dest_reg: dest,
});
let sort_metadata = m.sorts.get_mut(id).unwrap();
program.emit_insn(Insn::SorterInsert {
cursor_id: sort_metadata.sort_cursor,
record_reg: dest,
});
sort_metadata.sort_register = start_reg;
if can_emit_row {
return self.result_row(program, referenced_tables, m, None);
}
Ok(true)
}
Operator::Projection { source, .. } => {
source.emit(program, m, referenced_tables, false)?;
if can_emit_row {
return self.result_row(program, referenced_tables, m, None);
}
Ok(true)
}
Operator::Scan {
predicates,
table_identifier,
id,
..
} => {
let cursor_id = program.resolve_cursor_id(table_identifier, None);
program.emit_insn(Insn::RewindAsync { cursor_id });
let rewind_label = program.allocate_label();
let halt_label = m.termination_labels.last().unwrap();
m.rewind_labels.push(rewind_label);
program.defer_label_resolution(rewind_label, program.offset() as usize);
program.emit_insn_with_label_dependency(
Insn::RewindAwait {
cursor_id,
pc_if_empty: *halt_label,
},
*halt_label,
);
let jump_label = m.next_row_labels.get(id).unwrap_or(halt_label);
if let Some(preds) = predicates {
for expr in preds {
let jump_target_when_true = program.allocate_label();
let condition_metadata = ConditionMetadata {
jump_if_condition_is_true: false,
jump_target_when_true,
jump_target_when_false: *jump_label,
};
translate_condition_expr(
program,
referenced_tables,
expr,
None,
condition_metadata,
)?;
program.resolve_label(jump_target_when_true, program.offset());
}
}
if can_emit_row {
return self.result_row(program, referenced_tables, m, None);
}
Ok(true)
}
Operator::Nothing => Ok(false),
}
}
fn end(
&mut self,
program: &mut ProgramBuilder,
m: &mut Metadata,
referenced_tables: &[(Rc<BTreeTable>, String)],
) -> Result<()> {
match self {
Operator::Scan {
table_identifier,
id,
..
} => {
let cursor_id = program.resolve_cursor_id(table_identifier, None);
program.resolve_label(*m.next_row_labels.get(id).unwrap(), program.offset());
program.emit_insn(Insn::NextAsync { cursor_id });
let jump_label = m.rewind_labels.pop().unwrap();
program.emit_insn_with_label_dependency(
Insn::NextAwait {
cursor_id,
pc_if_next: jump_label,
},
jump_label,
);
Ok(())
}
Operator::Join {
left,
right,
outer,
id,
..
} => {
right.end(program, m, referenced_tables)?;
if *outer {
let lj_meta = m.left_joins.get(id).unwrap();
// If the left join match flag has been set to 1, we jump to the next row on the outer table (result row has been emitted already)
program.resolve_label(lj_meta.check_match_flag_label, program.offset());
program.emit_insn_with_label_dependency(
Insn::IfPos {
reg: lj_meta.match_flag_register,
target_pc: lj_meta.on_match_jump_to_label,
decrement_by: 0,
},
lj_meta.on_match_jump_to_label,
);
// If not, we set the right table cursor's "pseudo null bit" on, which means any Insn::Column will return NULL
let right_cursor_id = match right.as_ref() {
Operator::Scan {
table_identifier, ..
} => program.resolve_cursor_id(table_identifier, None),
Operator::SeekRowid {
table_identifier, ..
} => program.resolve_cursor_id(table_identifier, None),
_ => unreachable!(),
};
program.emit_insn(Insn::NullRow {
cursor_id: right_cursor_id,
});
// Jump to setting the left join match flag to 1 again, but this time the right table cursor will set everything to null
program.emit_insn_with_label_dependency(
Insn::Goto {
target_pc: lj_meta.set_match_flag_true_label,
},
lj_meta.set_match_flag_true_label,
);
// This points to the NextAsync instruction of the left table
program.resolve_label(lj_meta.on_match_jump_to_label, program.offset());
}
left.end(program, m, referenced_tables)
}
Operator::Aggregate {
id,
source,
aggregates,
} => {
source.end(program, m, referenced_tables)?;
program.resolve_label(m.termination_labels.pop().unwrap(), program.offset());
let start_reg = m.aggregations.get(id).unwrap();
for (i, agg) in aggregates.iter().enumerate() {
let agg_result_reg = *start_reg + i;
program.emit_insn(Insn::AggFinal {
register: agg_result_reg,
func: agg.func.clone(),
});
}
program.emit_insn(Insn::ResultRow {
start_reg: *start_reg,
count: aggregates.len(),
});
Ok(())
}
Operator::Filter { .. } => unreachable!("predicates have been pushed down"),
Operator::SeekRowid { .. } => Ok(()),
Operator::Limit { source, limit, .. } => {
source.result_row(program, referenced_tables, m, None)?;
let limit_reg = program.alloc_register();
program.emit_insn(Insn::Integer {
value: *limit as i64,
dest: limit_reg,
});
program.mark_last_insn_constant();
let jump_label = m.termination_labels.last().unwrap();
program.emit_insn_with_label_dependency(
Insn::DecrJumpZero {
reg: limit_reg,
target_pc: *jump_label,
},
*jump_label,
);
source.end(program, m, referenced_tables)?;
Ok(())
}
Operator::Order { id, .. } => {
let sort_metadata = m.sorts.get(id).unwrap();
program.emit_insn_with_label_dependency(
Insn::SorterNext {
cursor_id: sort_metadata.sort_cursor,
pc_if_next: sort_metadata.next_row_label,
},
sort_metadata.next_row_label,
);
program.resolve_label(sort_metadata.done_label, program.offset());
Ok(())
}
Operator::Projection { source, .. } => source.end(program, m, referenced_tables),
Operator::Nothing => Ok(()),
}
}
fn result_columns(
&self,
program: &mut ProgramBuilder,
referenced_tables: &[(Rc<BTreeTable>, String)],
m: &mut Metadata,
cursor_override: Option<usize>,
) -> Result<usize> {
let col_count = self.column_count(referenced_tables);
match self {
Operator::Scan {
table,
table_identifier,
..
} => {
let start_reg = program.alloc_registers(col_count);
table_columns(program, table, table_identifier, cursor_override, start_reg);
Ok(start_reg)
}
Operator::Join { left, right, .. } => {
let left_start_reg =
left.result_columns(program, referenced_tables, m, cursor_override)?;
right.result_columns(program, referenced_tables, m, cursor_override)?;
Ok(left_start_reg)
}
Operator::Aggregate { id, aggregates, .. } => {
let start_reg = m.aggregations.get(id).unwrap();
for (i, agg) in aggregates.iter().enumerate() {
let agg_result_reg = *start_reg + i;
program.emit_insn(Insn::AggFinal {
register: agg_result_reg,
func: agg.func.clone(),
});
}
Ok(*start_reg)
}
Operator::Filter { .. } => unreachable!("predicates have been pushed down"),
Operator::SeekRowid {
table_identifier, ..
} => {
let cursor_id =
cursor_override.unwrap_or(program.resolve_cursor_id(table_identifier, None));
let start_reg = program.alloc_registers(col_count);
for i in 0..col_count {
program.emit_insn(Insn::Column {
cursor_id,
column: i,
dest: start_reg + i,
});
}
Ok(start_reg)
}
Operator::Limit { .. } => {
unimplemented!()
}
Operator::Order { .. } => {
todo!()
}
Operator::Projection { expressions, .. } => {
let expr_count = expressions
.iter()
.map(|e| e.column_count(referenced_tables))
.sum();
let start_reg = program.alloc_registers(expr_count);
let mut cur_reg = start_reg;
for expr in expressions {
match expr {
ProjectionColumn::Column(expr) => {
translate_expr(
program,
Some(referenced_tables),
expr,
cur_reg,
cursor_override,
)?;
cur_reg += 1;
}
ProjectionColumn::Star => {
for (table, table_identifier) in referenced_tables.iter() {
cur_reg = table_columns(
program,
table,
table_identifier,
cursor_override,
cur_reg,
);
}
}
ProjectionColumn::TableStar(table, table_identifier) => {
let (table, table_identifier) = referenced_tables
.iter()
.find(|(_, id)| id == table_identifier)
.unwrap();
let cursor_id = cursor_override
.unwrap_or(program.resolve_cursor_id(table_identifier, None));
cur_reg = table_columns(
program,
table,
table_identifier,
Some(cursor_id),
cur_reg,
);
}
}
}
Ok(start_reg)
}
Operator::Nothing => unimplemented!(),
}
}
fn result_row(
&mut self,
program: &mut ProgramBuilder,
referenced_tables: &[(Rc<BTreeTable>, String)],
m: &mut Metadata,
cursor_override: Option<usize>,
) -> Result<bool> {
match self {
Operator::Order { id, source, key } => {
source.end(program, m, referenced_tables)?;
let column_names = source.column_names();
let pseudo_columns = column_names
.iter()
.map(|name| Column {
name: name.clone(),
primary_key: false,
ty: crate::schema::Type::Null,
})
.collect::<Vec<_>>();
let pseudo_cursor = program.alloc_cursor_id(
None,
Some(Table::Pseudo(Rc::new(PseudoTable {
columns: pseudo_columns,
}))),
);
let pseudo_content_reg = program.alloc_register();
program.emit_insn(Insn::OpenPseudo {
cursor_id: pseudo_cursor,
content_reg: pseudo_content_reg,
num_fields: key.len() + source.column_count(referenced_tables),
});
let sort_metadata = m.sorts.get(id).unwrap();
program.emit_insn_with_label_dependency(
Insn::SorterSort {
cursor_id: sort_metadata.sort_cursor,
pc_if_empty: sort_metadata.done_label,
},
sort_metadata.done_label,
);
program.defer_label_resolution(
sort_metadata.next_row_label,
program.offset() as usize,
);
program.emit_insn(Insn::SorterData {
cursor_id: sort_metadata.sort_cursor,
dest_reg: pseudo_content_reg,
pseudo_cursor,
});
let done_label = sort_metadata.done_label;
source.result_row(program, referenced_tables, m, Some(pseudo_cursor))?;
program.resolve_label(done_label, program.offset());
Ok(true)
}
node => {
let start_reg =
node.result_columns(program, referenced_tables, m, cursor_override)?;
program.emit_insn(Insn::ResultRow {
start_reg,
count: node.column_count(referenced_tables),
});
Ok(true)
}
}
}
}
pub fn emit_program(
database_header: Rc<RefCell<DatabaseHeader>>,
mut select_plan: Plan,
) -> Result<Program> {
let mut program = ProgramBuilder::new();
let init_label = program.allocate_label();
let halt_label = program.allocate_label();
program.emit_insn_with_label_dependency(
Insn::Init {
target_pc: init_label,
},
init_label,
);
let start_offset = program.offset();
let mut metadata = Metadata {
termination_labels: vec![halt_label],
next_row_labels: HashMap::new(),
rewind_labels: Vec::new(),
aggregations: HashMap::new(),
sorts: HashMap::new(),
left_joins: HashMap::new(),
};
select_plan
.root_node
.start(&mut program, &mut metadata, &select_plan.referenced_tables)?;
select_plan.root_node.emit(
&mut program,
&mut metadata,
&select_plan.referenced_tables,
true,
)?;
select_plan
.root_node
.end(&mut program, &mut metadata, &select_plan.referenced_tables)?;
program.resolve_label(halt_label, program.offset());
program.emit_insn(Insn::Halt);
program.resolve_label(init_label, program.offset());
program.emit_insn(Insn::Transaction);
program.emit_constant_insns();
program.emit_insn(Insn::Goto {
target_pc: start_offset,
});
program.resolve_deferred_labels();
Ok(program.build(database_header))
}
fn table_columns(
program: &mut ProgramBuilder,
table: &Rc<BTreeTable>,
table_identifier: &str,
cursor_override: Option<usize>,
start_reg: usize,
) -> usize {
let mut cur_reg = start_reg;
let cursor_id = cursor_override.unwrap_or(program.resolve_cursor_id(table_identifier, None));
for i in 0..table.columns.len() {
let is_primary_key = table.columns[i].primary_key;
let col_type = &table.columns[i].ty;
if is_primary_key {
program.emit_insn(Insn::RowId {
cursor_id,
dest: cur_reg,
});
} else {
program.emit_insn(Insn::Column {
cursor_id,
column: i,
dest: cur_reg,
});
}
maybe_apply_affinity(*col_type, cur_reg, program);
cur_reg += 1;
}
cur_reg
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,10 +7,13 @@
//! a SELECT statement will be translated into a sequence of instructions that
//! will read rows from the database and filter them according to a WHERE clause.
pub(crate) mod emitter;
pub(crate) mod expr;
pub(crate) mod insert;
pub(crate) mod optimizer;
pub(crate) mod plan;
pub(crate) mod planner;
pub(crate) mod select;
pub(crate) mod where_clause;
use std::cell::RefCell;
use std::rc::Rc;
@@ -18,11 +21,10 @@ use std::rc::Rc;
use crate::schema::Schema;
use crate::storage::pager::Pager;
use crate::storage::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE};
use crate::util::normalize_ident;
use crate::vdbe::{builder::ProgramBuilder, Insn, Program};
use crate::{bail_parse_error, Result};
use insert::translate_insert;
use select::{prepare_select, translate_select};
use select::translate_select;
use sqlite3_parser::ast;
/// Translate SQL statement into bytecode program.
@@ -56,10 +58,7 @@ pub fn translate(
ast::Stmt::Release(_) => bail_parse_error!("RELEASE not supported yet"),
ast::Stmt::Rollback { .. } => bail_parse_error!("ROLLBACK not supported yet"),
ast::Stmt::Savepoint(_) => bail_parse_error!("SAVEPOINT not supported yet"),
ast::Stmt::Select(select) => {
let select = prepare_select(schema, &select)?;
translate_select(select, database_header)
}
ast::Stmt::Select(select) => translate_select(schema, select, database_header),
ast::Stmt::Update { .. } => bail_parse_error!("UPDATE not supported yet"),
ast::Stmt::Vacuum(_, _) => bail_parse_error!("VACUUM not supported yet"),
ast::Stmt::Insert {

732
core/translate/optimizer.rs Normal file
View File

@@ -0,0 +1,732 @@
use std::rc::Rc;
use sqlite3_parser::ast;
use crate::{schema::BTreeTable, util::normalize_ident, Result};
use super::plan::{
get_table_ref_bitmask_for_ast_expr, get_table_ref_bitmask_for_query_plan_node, Operator, Plan,
};
/**
* Make a few passes over the plan to optimize it: push predicates down, eliminate constant predicates, and use indexes (rowid seeks) where possible.
*/
pub fn optimize_plan(mut select_plan: Plan) -> Result<Plan> {
push_predicates(&mut select_plan.root_node, &select_plan.referenced_tables)?;
eliminate_constants(&mut select_plan.root_node)?;
use_indexes(&mut select_plan.root_node, &select_plan.referenced_tables)?;
Ok(select_plan)
}
/**
* Use indexes where possible (currently just primary key lookups)
*/
fn use_indexes(node: &mut Operator, referenced_tables: &[(Rc<BTreeTable>, String)]) -> Result<()> {
match node {
Operator::Scan {
table,
predicates: filter,
table_identifier,
id,
} => {
if filter.is_none() {
return Ok(());
}
let fs = filter.as_mut().unwrap();
let mut i = 0;
let mut maybe_rowid_predicate = None;
while i < fs.len() {
let f = fs[i].take_ownership();
let table_index = referenced_tables
.iter()
.position(|(t, t_id)| Rc::ptr_eq(t, table) && t_id == table_identifier)
.unwrap();
let (can_use, expr) =
try_extract_rowid_comparison_expression(f, table_index, referenced_tables)?;
if can_use {
maybe_rowid_predicate = Some(expr);
fs.remove(i);
break;
} else {
fs[i] = expr;
i += 1;
}
}
if let Some(rowid_predicate) = maybe_rowid_predicate {
let predicates_owned = if fs.is_empty() {
None
} else {
Some(fs.drain(..).collect())
};
*node = Operator::SeekRowid {
table: table.clone(),
table_identifier: table_identifier.clone(),
rowid_predicate,
predicates: predicates_owned,
id: *id,
}
}
return Ok(());
}
Operator::Aggregate { source, .. } => {
use_indexes(source, referenced_tables)?;
return Ok(());
}
Operator::Filter { source, .. } => {
use_indexes(source, referenced_tables)?;
return Ok(());
}
Operator::SeekRowid { .. } => {
return Ok(());
}
Operator::Limit { source, .. } => {
use_indexes(source, referenced_tables)?;
return Ok(());
}
Operator::Join { left, right, .. } => {
use_indexes(left, referenced_tables)?;
use_indexes(right, referenced_tables)?;
return Ok(());
}
Operator::Order { source, .. } => {
use_indexes(source, referenced_tables)?;
return Ok(());
}
Operator::Projection { source, .. } => {
use_indexes(source, referenced_tables)?;
return Ok(());
}
Operator::Nothing => {
return Ok(());
}
}
}
// removes predicates that are always true
// returns false if any predicate is always false, i.e. the node can never produce rows
fn eliminate_constants(node: &mut Operator) -> Result<bool> {
match node {
Operator::Filter {
source, predicates, ..
} => {
let mut i = 0;
while i < predicates.len() {
let predicate = &predicates[i];
if predicate.is_always_true()? {
predicates.remove(i);
} else if predicate.is_always_false()? {
return Ok(false);
} else {
i += 1;
}
}
if predicates.is_empty() {
*node = source.take_ownership();
eliminate_constants(node)?;
} else {
eliminate_constants(source)?;
}
return Ok(true);
}
Operator::Join {
left,
right,
predicates,
outer,
..
} => {
if !eliminate_constants(left)? {
return Ok(false);
}
if !eliminate_constants(right)? && !*outer {
return Ok(false);
}
if predicates.is_none() {
return Ok(true);
}
let predicates = predicates.as_mut().unwrap();
let mut i = 0;
while i < predicates.len() {
let predicate = &predicates[i];
if predicate.is_always_true()? {
predicates.remove(i);
} else if predicate.is_always_false()? && !*outer {
return Ok(false);
} else {
i += 1;
}
}
return Ok(true);
}
Operator::Aggregate { source, .. } => {
let ok = eliminate_constants(source)?;
if !ok {
*source = Box::new(Operator::Nothing);
}
return Ok(ok);
}
Operator::SeekRowid {
rowid_predicate,
predicates,
..
} => {
if let Some(predicates) = predicates {
let mut i = 0;
while i < predicates.len() {
let predicate = &predicates[i];
if predicate.is_always_true()? {
predicates.remove(i);
} else if predicate.is_always_false()? {
return Ok(false);
} else {
i += 1;
}
}
}
if rowid_predicate.is_always_false()? {
return Ok(false);
}
return Ok(true);
}
Operator::Limit { source, .. } => {
let ok = eliminate_constants(source)?;
if !ok {
*node = Operator::Nothing;
}
return Ok(ok);
}
Operator::Order { source, .. } => {
let ok = eliminate_constants(source)?;
if !ok {
*node = Operator::Nothing;
}
return Ok(true);
}
Operator::Projection { source, .. } => {
let ok = eliminate_constants(source)?;
if !ok {
*node = Operator::Nothing;
}
return Ok(ok);
}
Operator::Scan { predicates, .. } => {
if let Some(ps) = predicates {
let mut i = 0;
while i < ps.len() {
let predicate = &ps[i];
if predicate.is_always_true()? {
ps.remove(i);
} else if predicate.is_always_false()? {
return Ok(false);
} else {
i += 1;
}
}
if ps.is_empty() {
*predicates = None;
}
}
return Ok(true);
}
Operator::Nothing => return Ok(true),
}
}
/**
Recursively pushes predicates down the tree, as far as possible.
*/
fn push_predicates(
node: &mut Operator,
referenced_tables: &Vec<(Rc<BTreeTable>, String)>,
) -> Result<()> {
match node {
Operator::Filter {
source, predicates, ..
} => {
let mut i = 0;
while i < predicates.len() {
// try to push the predicate to the source
// if it succeeds, remove the predicate from the filter
let predicate_owned = predicates[i].take_ownership();
let Some(predicate) = push_predicate(source, predicate_owned, referenced_tables)?
else {
predicates.remove(i);
continue;
};
predicates[i] = predicate;
i += 1;
}
if predicates.is_empty() {
*node = source.take_ownership();
}
return Ok(());
}
Operator::Join {
left,
right,
predicates,
outer,
..
} => {
push_predicates(left, referenced_tables)?;
push_predicates(right, referenced_tables)?;
if predicates.is_none() {
return Ok(());
}
let predicates = predicates.as_mut().unwrap();
let mut i = 0;
while i < predicates.len() {
// try to push the predicate to the left side first, then to the right side
// temporarily take ownership of the predicate
let predicate_owned = predicates[i].take_ownership();
// left join predicates can't be pushed to the left side
let push_result = if *outer {
Some(predicate_owned)
} else {
push_predicate(left, predicate_owned, referenced_tables)?
};
// if the predicate was pushed to a child, remove it from the list
let Some(predicate) = push_result else {
predicates.remove(i);
continue;
};
// otherwise try to push it to the right side
// if it was pushed to the right side, remove it from the list
let Some(predicate) = push_predicate(right, predicate, referenced_tables)? else {
predicates.remove(i);
continue;
};
// otherwise keep the predicate in the list
predicates[i] = predicate;
i += 1;
}
return Ok(());
}
Operator::Aggregate { source, .. } => {
push_predicates(source, referenced_tables)?;
return Ok(());
}
Operator::SeekRowid { .. } => {
return Ok(());
}
Operator::Limit { source, .. } => {
push_predicates(source, referenced_tables)?;
return Ok(());
}
Operator::Order { source, .. } => {
push_predicates(source, referenced_tables)?;
return Ok(());
}
Operator::Projection { source, .. } => {
push_predicates(source, referenced_tables)?;
return Ok(());
}
Operator::Scan { .. } => {
return Ok(());
}
Operator::Nothing => {
return Ok(());
}
}
}
/**
Push a single predicate down the tree, as far as possible.
Returns Ok(None) if the predicate was pushed; otherwise returns the predicate back as Ok(Some(predicate))
*/
fn push_predicate(
node: &mut Operator,
predicate: ast::Expr,
referenced_tables: &Vec<(Rc<BTreeTable>, String)>,
) -> Result<Option<ast::Expr>> {
match node {
Operator::Scan {
predicates,
table_identifier,
..
} => {
let table_index = referenced_tables
.iter()
.position(|(_, t_id)| t_id == table_identifier)
.unwrap();
let predicate_bitmask =
get_table_ref_bitmask_for_ast_expr(referenced_tables, &predicate)?;
// the expression may only refer to this table and tables on its left in the join order, i.e. the lower (less significant) bits in the mask
// e.g. if this table is 0010, and the table on its right in the join is 0100:
// if predicate_bitmask is 0011, the predicate can be pushed (refers to this table and the table on its left)
// if predicate_bitmask is 0001, the predicate can be pushed (refers to the table on its left)
// if predicate_bitmask is 0101, the predicate can't be pushed (refers to this table and a table on its right)
let next_table_on_the_right_in_join_bitmask = 1 << (table_index + 1);
if predicate_bitmask >= next_table_on_the_right_in_join_bitmask {
return Ok(Some(predicate));
}
if predicates.is_none() {
predicates.replace(vec![predicate]);
} else {
predicates.as_mut().unwrap().push(predicate);
}
return Ok(None);
}
Operator::Filter {
source,
predicates: ps,
..
} => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
ps.push(push_result.unwrap());
return Ok(None);
}
Operator::Join {
left,
right,
predicates: join_on_preds,
outer,
..
} => {
let push_result_left = push_predicate(left, predicate, referenced_tables)?;
if push_result_left.is_none() {
return Ok(None);
}
let push_result_right =
push_predicate(right, push_result_left.unwrap(), referenced_tables)?;
if push_result_right.is_none() {
return Ok(None);
}
if *outer {
return Ok(Some(push_result_right.unwrap()));
}
let pred = push_result_right.unwrap();
let table_refs_bitmask = get_table_ref_bitmask_for_ast_expr(referenced_tables, &pred)?;
let left_bitmask = get_table_ref_bitmask_for_query_plan_node(referenced_tables, left)?;
let right_bitmask =
get_table_ref_bitmask_for_query_plan_node(referenced_tables, right)?;
if table_refs_bitmask & left_bitmask == 0 || table_refs_bitmask & right_bitmask == 0 {
return Ok(Some(pred));
}
if join_on_preds.is_none() {
join_on_preds.replace(vec![pred]);
} else {
join_on_preds.as_mut().unwrap().push(pred);
}
return Ok(None);
}
Operator::Aggregate { source, .. } => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
return Ok(Some(push_result.unwrap()));
}
Operator::SeekRowid { .. } => {
return Ok(Some(predicate));
}
Operator::Limit { source, .. } => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
return Ok(Some(push_result.unwrap()));
}
Operator::Order { source, .. } => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
return Ok(Some(push_result.unwrap()));
}
Operator::Projection { source, .. } => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
return Ok(Some(push_result.unwrap()));
}
Operator::Nothing => {
return Ok(Some(predicate));
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstantPredicate {
AlwaysTrue,
AlwaysFalse,
}
/**
Helper trait for expressions that can be optimized
Implemented for ast::Expr
*/
pub trait Optimizable {
// if the expression is a constant expression e.g. '1', returns the constant condition
fn check_constant(&self) -> Result<Option<ConstantPredicate>>;
fn is_always_true(&self) -> Result<bool> {
Ok(self
.check_constant()?
.map_or(false, |c| c == ConstantPredicate::AlwaysTrue))
}
fn is_always_false(&self) -> Result<bool> {
Ok(self
.check_constant()?
.map_or(false, |c| c == ConstantPredicate::AlwaysFalse))
}
// if the expression is the primary key of a table, returns the index of the table
fn check_primary_key(
&self,
referenced_tables: &[(Rc<BTreeTable>, String)],
) -> Result<Option<usize>>;
}
impl Optimizable for ast::Expr {
fn check_primary_key(
&self,
referenced_tables: &[(Rc<BTreeTable>, String)],
) -> Result<Option<usize>> {
match self {
ast::Expr::Id(ident) => {
let ident = normalize_ident(&ident.0);
let tables = referenced_tables
.iter()
.enumerate()
.filter_map(|(i, (t, _))| {
if t.get_column(&ident).map_or(false, |(_, c)| c.primary_key) {
Some(i)
} else {
None
}
});
let mut matches = 0;
let mut matching_tbl = None;
for tbl in tables {
matching_tbl = Some(tbl);
matches += 1;
if matches > 1 {
crate::bail_parse_error!("ambiguous column name {}", ident)
}
}
Ok(matching_tbl)
}
ast::Expr::Qualified(tbl, ident) => {
let tbl = normalize_ident(&tbl.0);
let ident = normalize_ident(&ident.0);
let table = referenced_tables.iter().enumerate().find(|(_, (t, t_id))| {
*t_id == tbl && t.get_column(&ident).map_or(false, |(_, c)| c.primary_key)
});
if table.is_none() {
return Ok(None);
}
let table = table.unwrap();
Ok(Some(table.0))
}
_ => Ok(None),
}
}
fn check_constant(&self) -> Result<Option<ConstantPredicate>> {
match self {
ast::Expr::Literal(lit) => match lit {
ast::Literal::Null => Ok(Some(ConstantPredicate::AlwaysFalse)),
ast::Literal::Numeric(b) => {
if let Ok(int_value) = b.parse::<i64>() {
return Ok(Some(if int_value == 0 {
ConstantPredicate::AlwaysFalse
} else {
ConstantPredicate::AlwaysTrue
}));
}
if let Ok(float_value) = b.parse::<f64>() {
return Ok(Some(if float_value == 0.0 {
ConstantPredicate::AlwaysFalse
} else {
ConstantPredicate::AlwaysTrue
}));
}
Ok(None)
}
ast::Literal::String(s) => {
let without_quotes = s.trim_matches('\'');
if let Ok(int_value) = without_quotes.parse::<i64>() {
return Ok(Some(if int_value == 0 {
ConstantPredicate::AlwaysFalse
} else {
ConstantPredicate::AlwaysTrue
}));
}
if let Ok(float_value) = without_quotes.parse::<f64>() {
return Ok(Some(if float_value == 0.0 {
ConstantPredicate::AlwaysFalse
} else {
ConstantPredicate::AlwaysTrue
}));
}
Ok(Some(ConstantPredicate::AlwaysFalse))
}
_ => Ok(None),
},
ast::Expr::Unary(op, expr) => {
if *op == ast::UnaryOperator::Not {
let trivial = expr.check_constant()?;
return Ok(trivial.map(|t| match t {
ConstantPredicate::AlwaysTrue => ConstantPredicate::AlwaysFalse,
ConstantPredicate::AlwaysFalse => ConstantPredicate::AlwaysTrue,
}));
}
if *op == ast::UnaryOperator::Negative {
let trivial = expr.check_constant()?;
return Ok(trivial);
}
Ok(None)
}
ast::Expr::InList { lhs: _, not, rhs } => {
if rhs.is_none() {
return Ok(Some(if *not {
ConstantPredicate::AlwaysTrue
} else {
ConstantPredicate::AlwaysFalse
}));
}
let rhs = rhs.as_ref().unwrap();
if rhs.is_empty() {
return Ok(Some(if *not {
ConstantPredicate::AlwaysTrue
} else {
ConstantPredicate::AlwaysFalse
}));
}
Ok(None)
}
ast::Expr::Binary(lhs, op, rhs) => {
let lhs_trivial = lhs.check_constant()?;
let rhs_trivial = rhs.check_constant()?;
match op {
ast::Operator::And => {
if lhs_trivial == Some(ConstantPredicate::AlwaysFalse)
|| rhs_trivial == Some(ConstantPredicate::AlwaysFalse)
{
return Ok(Some(ConstantPredicate::AlwaysFalse));
}
if lhs_trivial == Some(ConstantPredicate::AlwaysTrue)
&& rhs_trivial == Some(ConstantPredicate::AlwaysTrue)
{
return Ok(Some(ConstantPredicate::AlwaysTrue));
}
Ok(None)
}
ast::Operator::Or => {
if lhs_trivial == Some(ConstantPredicate::AlwaysTrue)
|| rhs_trivial == Some(ConstantPredicate::AlwaysTrue)
{
return Ok(Some(ConstantPredicate::AlwaysTrue));
}
if lhs_trivial == Some(ConstantPredicate::AlwaysFalse)
&& rhs_trivial == Some(ConstantPredicate::AlwaysFalse)
{
return Ok(Some(ConstantPredicate::AlwaysFalse));
}
Ok(None)
}
_ => Ok(None),
}
}
_ => Ok(None),
}
}
}
pub fn try_extract_rowid_comparison_expression(
expr: ast::Expr,
table_index: usize,
referenced_tables: &[(Rc<BTreeTable>, String)],
) -> Result<(bool, ast::Expr)> {
match expr {
ast::Expr::Binary(lhs, ast::Operator::Equals, rhs) => {
if let Some(lhs_table_index) = lhs.check_primary_key(referenced_tables)? {
if lhs_table_index == table_index {
return Ok((true, *rhs));
}
}
if let Some(rhs_table_index) = rhs.check_primary_key(referenced_tables)? {
if rhs_table_index == table_index {
return Ok((true, *lhs));
}
}
Ok((false, ast::Expr::Binary(lhs, ast::Operator::Equals, rhs)))
}
_ => Ok((false, expr)),
}
}
trait TakeOwnership {
fn take_ownership(&mut self) -> Self;
}
impl TakeOwnership for ast::Expr {
fn take_ownership(&mut self) -> Self {
std::mem::replace(self, ast::Expr::Literal(ast::Literal::Null))
}
}
impl TakeOwnership for Operator {
fn take_ownership(&mut self) -> Self {
std::mem::replace(self, Operator::Nothing)
}
}
fn replace_with<T: TakeOwnership>(expr: &mut T, mut replacement: T) {
*expr = replacement.take_ownership();
}
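To illustrate the three passes working together (assuming users.id is an INTEGER PRIMARY KEY, i.e. a rowid alias): for `SELECT * FROM users u WHERE 1 AND u.id = 5`, the planner first builds a Filter with predicates [1, u.id = 5] on top of a Scan. push_predicates moves both predicates into the Scan and replaces the now-empty Filter with its source; eliminate_constants drops the always-true `1`; and use_indexes recognizes `u.id = 5` as a primary key comparison and rewrites the Scan into a SeekRowid, so the optimized plan prints roughly as:

SEEK users.rowid ON rowid=5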

469
core/translate/plan.rs Normal file
View File

@@ -0,0 +1,469 @@
use core::fmt;
use std::{
fmt::{Display, Formatter},
rc::Rc,
};
use sqlite3_parser::ast;
use crate::{function::AggFunc, schema::BTreeTable, util::normalize_ident, Result};
pub struct Plan {
pub root_node: Operator,
pub referenced_tables: Vec<(Rc<BTreeTable>, String)>,
}
impl Display for Plan {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.root_node)
}
}
/**
An Operator is a Node in the query plan.
Operators form a tree structure, with each having zero or more children.
For example, a query like `SELECT t1.foo FROM t1 ORDER BY t1.foo LIMIT 1` would have the following structure:
Limit
Order
Project
Scan
*/
#[derive(Clone, Debug)]
pub enum Operator {
Aggregate {
id: usize,
source: Box<Operator>,
aggregates: Vec<Aggregate>,
},
Filter {
id: usize,
source: Box<Operator>,
predicates: Vec<ast::Expr>,
},
SeekRowid {
id: usize,
table: Rc<BTreeTable>,
table_identifier: String,
rowid_predicate: ast::Expr,
predicates: Option<Vec<ast::Expr>>,
},
Limit {
id: usize,
source: Box<Operator>,
limit: usize,
},
Join {
id: usize,
left: Box<Operator>,
right: Box<Operator>,
predicates: Option<Vec<ast::Expr>>,
outer: bool,
},
Order {
id: usize,
source: Box<Operator>,
key: Vec<(ast::Expr, Direction)>,
},
Projection {
id: usize,
source: Box<Operator>,
expressions: Vec<ProjectionColumn>,
},
Scan {
id: usize,
table: Rc<BTreeTable>,
table_identifier: String,
predicates: Option<Vec<ast::Expr>>,
},
Nothing,
}
#[derive(Clone, Debug)]
pub enum ProjectionColumn {
Column(ast::Expr),
Star,
TableStar(Rc<BTreeTable>, String),
}
impl ProjectionColumn {
pub fn column_count(&self, referenced_tables: &[(Rc<BTreeTable>, String)]) -> usize {
match self {
ProjectionColumn::Column(_) => 1,
ProjectionColumn::Star => {
let mut count = 0;
for (table, _) in referenced_tables {
count += table.columns.len();
}
count
}
ProjectionColumn::TableStar(table, _) => table.columns.len(),
}
}
}
impl Operator {
pub fn column_count(&self, referenced_tables: &[(Rc<BTreeTable>, String)]) -> usize {
match self {
Operator::Aggregate { aggregates, .. } => aggregates.len(),
Operator::Filter { source, .. } => source.column_count(referenced_tables),
Operator::SeekRowid { table, .. } => table.columns.len(),
Operator::Limit { source, .. } => source.column_count(referenced_tables),
Operator::Join { left, right, .. } => {
left.column_count(referenced_tables) + right.column_count(referenced_tables)
}
Operator::Order { source, .. } => source.column_count(referenced_tables),
Operator::Projection { expressions, .. } => expressions
.iter()
.map(|e| e.column_count(referenced_tables))
.sum(),
Operator::Scan { table, .. } => table.columns.len(),
Operator::Nothing => 0,
}
}
pub fn column_names(&self) -> Vec<String> {
match self {
Operator::Aggregate { .. } => {
todo!();
}
Operator::Filter { source, .. } => source.column_names(),
Operator::SeekRowid { table, .. } => {
table.columns.iter().map(|c| c.name.clone()).collect()
}
Operator::Limit { source, .. } => source.column_names(),
Operator::Join { left, right, .. } => {
let mut names = left.column_names();
names.extend(right.column_names());
names
}
Operator::Order { source, .. } => source.column_names(),
Operator::Projection { expressions, .. } => expressions
.iter()
.map(|e| match e {
ProjectionColumn::Column(expr) => match expr {
ast::Expr::Id(ident) => ident.0.clone(),
ast::Expr::Qualified(tbl, ident) => format!("{}.{}", tbl.0, ident.0),
_ => "expr".to_string(),
},
ProjectionColumn::Star => "*".to_string(),
ProjectionColumn::TableStar(_, tbl) => format!("{}.{}", tbl, "*"),
})
.collect(),
Operator::Scan { table, .. } => table.columns.iter().map(|c| c.name.clone()).collect(),
Operator::Nothing => vec![],
}
}
pub fn id(&self) -> usize {
match self {
Operator::Aggregate { id, .. } => *id,
Operator::Filter { id, .. } => *id,
Operator::SeekRowid { id, .. } => *id,
Operator::Limit { id, .. } => *id,
Operator::Join { id, .. } => *id,
Operator::Order { id, .. } => *id,
Operator::Projection { id, .. } => *id,
Operator::Scan { id, .. } => *id,
Operator::Nothing => unreachable!(),
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Direction {
Ascending,
Descending,
}
impl Display for Direction {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
Direction::Ascending => write!(f, "ASC"),
Direction::Descending => write!(f, "DESC"),
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub struct Aggregate {
pub func: AggFunc,
pub args: Vec<ast::Expr>,
}
impl Display for Aggregate {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let args_str = self
.args
.iter()
.map(|arg| arg.to_string())
.collect::<Vec<String>>()
.join(", ");
write!(f, "{:?}({})", self.func, args_str)
}
}
// For EXPLAIN QUERY PLAN
impl Display for Operator {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
fn fmt_node(node: &Operator, f: &mut Formatter, level: usize) -> fmt::Result {
let indent = " ".repeat(level);
match node {
Operator::Aggregate {
source, aggregates, ..
} => {
// e.g. Aggregate count(*), sum(x)
let aggregates_display_string = aggregates
.iter()
.map(|agg| agg.to_string())
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}AGGREGATE {}", indent, aggregates_display_string)?;
fmt_node(source, f, level + 1)
}
Operator::Filter {
source, predicates, ..
} => {
let predicates_string = predicates
.iter()
.map(|p| p.to_string())
.collect::<Vec<String>>()
.join(" AND ");
writeln!(f, "{}FILTER {}", indent, predicates_string)?;
fmt_node(source, f, level + 1)
}
Operator::SeekRowid {
table,
rowid_predicate,
predicates,
..
} => {
match predicates {
Some(ps) => {
let predicates_string = ps
.iter()
.map(|p| p.to_string())
.collect::<Vec<String>>()
.join(" AND ");
writeln!(
f,
"{}SEEK {}.rowid ON rowid={} FILTER {}",
indent, &table.name, rowid_predicate, predicates_string
)?;
}
None => writeln!(
f,
"{}SEEK {}.rowid ON rowid={}",
indent, &table.name, rowid_predicate
)?,
}
Ok(())
}
Operator::Limit { source, limit, .. } => {
writeln!(f, "{}TAKE {}", indent, limit)?;
fmt_node(source, f, level + 1)
}
Operator::Join {
left,
right,
predicates,
outer,
..
} => {
let join_name = if *outer { "OUTER JOIN" } else { "JOIN" };
match predicates
.as_ref()
.and_then(|ps| if ps.is_empty() { None } else { Some(ps) })
{
Some(ps) => {
let predicates_string = ps
.iter()
.map(|p| p.to_string())
.collect::<Vec<String>>()
.join(" AND ");
writeln!(f, "{}{} ON {}", indent, join_name, predicates_string)?;
}
None => writeln!(f, "{}{}", indent, join_name)?,
}
fmt_node(left, f, level + 1)?;
fmt_node(right, f, level + 1)
}
Operator::Order { source, key, .. } => {
let sort_keys_string = key
.iter()
.map(|(expr, dir)| format!("{} {}", expr, dir))
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}SORT {}", indent, sort_keys_string)?;
fmt_node(source, f, level + 1)
}
Operator::Projection {
source,
expressions,
..
} => {
let expressions = expressions
.iter()
.map(|expr| match expr {
ProjectionColumn::Column(c) => c.to_string(),
ProjectionColumn::Star => "*".to_string(),
ProjectionColumn::TableStar(_, a) => format!("{}.{}", a, "*"),
})
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}PROJECT {}", indent, expressions)?;
fmt_node(source, f, level + 1)
}
Operator::Scan {
table,
predicates: filter,
table_identifier,
..
} => {
let table_name = format!("{} AS {}", &table.name, &table_identifier);
let filter_string = filter.as_ref().map(|f| {
let filters_string = f
.iter()
.map(|p| p.to_string())
.collect::<Vec<String>>()
.join(" AND ");
format!("FILTER {}", filters_string)
});
match filter_string {
Some(fs) => writeln!(f, "{}SCAN {} {}", indent, table_name, fs),
None => writeln!(f, "{}SCAN {}", indent, table_name),
}?;
Ok(())
}
Operator::Nothing => Ok(()),
}
}
fmt_node(self, f, 0)
}
}
pub fn get_table_ref_bitmask_for_query_plan_node<'a>(
tables: &'a Vec<(Rc<BTreeTable>, String)>,
node: &'a Operator,
) -> Result<usize> {
let mut table_refs_mask = 0;
match node {
Operator::Aggregate { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_query_plan_node(tables, source)?;
}
Operator::Filter {
source, predicates, ..
} => {
table_refs_mask |= get_table_ref_bitmask_for_query_plan_node(tables, source)?;
for predicate in predicates {
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, predicate)?;
}
}
Operator::SeekRowid { table, .. } => {
table_refs_mask |= 1
<< tables
.iter()
.position(|(t, _)| Rc::ptr_eq(t, table))
.unwrap();
}
Operator::Limit { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_query_plan_node(tables, source)?;
}
Operator::Join { left, right, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_query_plan_node(tables, left)?;
table_refs_mask |= get_table_ref_bitmask_for_query_plan_node(tables, right)?;
}
Operator::Order { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_query_plan_node(tables, source)?;
}
Operator::Projection { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_query_plan_node(tables, source)?;
}
Operator::Scan { table, .. } => {
table_refs_mask |= 1
<< tables
.iter()
.position(|(t, _)| Rc::ptr_eq(t, table))
.unwrap();
}
Operator::Nothing => {}
}
Ok(table_refs_mask)
}
pub fn get_table_ref_bitmask_for_ast_expr<'a>(
tables: &'a Vec<(Rc<BTreeTable>, String)>,
predicate: &'a ast::Expr,
) -> Result<usize> {
let mut table_refs_mask = 0;
match predicate {
ast::Expr::Binary(e1, _, e2) => {
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, e1)?;
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, e2)?;
}
ast::Expr::Id(ident) => {
let ident = normalize_ident(&ident.0);
let matching_tables = tables
.iter()
.enumerate()
.filter(|(_, (table, _))| table.get_column(&ident).is_some());
let mut matches = 0;
let mut matching_tbl = None;
for table in matching_tables {
matching_tbl = Some(table);
matches += 1;
if matches > 1 {
crate::bail_parse_error!("ambiguous column name {}", &ident)
}
}
if let Some((tbl_index, _)) = matching_tbl {
table_refs_mask |= 1 << tbl_index;
} else {
crate::bail_parse_error!("column not found: {}", &ident)
}
}
ast::Expr::Qualified(tbl, ident) => {
let tbl = normalize_ident(&tbl.0);
let ident = normalize_ident(&ident.0);
let matching_table = tables
.iter()
.enumerate()
.find(|(_, (table, t_id))| *t_id == tbl);
if matching_table.is_none() {
crate::bail_parse_error!("introspect: table not found: {}", &tbl)
}
let matching_table = matching_table.unwrap();
if matching_table.1 .0.get_column(&ident).is_none() {
crate::bail_parse_error!("column with qualified name {}.{} not found", &tbl, &ident)
}
table_refs_mask |= 1 << matching_table.0;
}
ast::Expr::Literal(_) => {}
ast::Expr::Like { lhs, rhs, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, lhs)?;
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, rhs)?;
}
ast::Expr::FunctionCall {
args: Some(args), ..
} => {
for arg in args {
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, arg)?;
}
}
ast::Expr::InList { lhs, rhs, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, lhs)?;
if let Some(rhs_list) = rhs {
for rhs_expr in rhs_list {
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, rhs_expr)?;
}
}
}
_ => {}
}
Ok(table_refs_mask)
}
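As a small worked example of the bitmask helpers: with referenced_tables = [(users, "u"), (products, "p")], get_table_ref_bitmask_for_ast_expr returns 0b01 for `u.age > 18`, 0b10 for `p.name = 'hat'`, 0b11 for `u.id = p.id`, and 0 for a bare literal like `1`. The optimizer compares these masks against table positions (and against get_table_ref_bitmask_for_query_plan_node of a subtree) to decide whether a predicate may be pushed below a given operator.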

359
core/translate/planner.rs Normal file
View File

@@ -0,0 +1,359 @@
use super::plan::{Aggregate, Direction, Operator, Plan, ProjectionColumn};
use crate::{
function::Func,
schema::{BTreeTable, Schema},
util::normalize_ident,
Result,
};
use sqlite3_parser::ast::{self, FromClause, JoinType, ResultColumn};
use std::rc::Rc;
pub struct NodeIdCounter {
id: usize,
}
impl NodeIdCounter {
pub fn new() -> Self {
Self { id: 0 }
}
pub fn get_next_id(&mut self) -> usize {
let id = self.id;
self.id += 1;
id
}
}
pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<Plan> {
match select.body.select {
ast::OneSelect::Select {
columns,
from,
where_clause,
..
} => {
let col_count = columns.len();
if col_count == 0 {
crate::bail_parse_error!("SELECT without columns is not allowed");
}
let mut node_id_counter = NodeIdCounter::new();
// Parse the FROM clause
let (mut node, referenced_tables) = parse_from(schema, from, &mut node_id_counter)?;
// Parse the WHERE clause
if let Some(w) = where_clause {
node = Operator::Filter {
source: Box::new(node),
predicates: break_predicate_at_and_boundaries(w, vec![]),
id: node_id_counter.get_next_id(),
};
}
// Parse the SELECT clause to either a projection or an aggregation
// depending on the presence of aggregate functions.
// Since GROUP BY is not supported yet, mixing aggregate and non-aggregate
// columns is not allowed.
//
// If there are no aggregate functions, we can simply project the columns.
// For a simple SELECT *, the projection node is skipped.
let is_select_star = col_count == 1 && matches!(columns[0], ast::ResultColumn::Star);
if !is_select_star {
let mut aggregate_expressions = Vec::new();
let mut scalar_expressions = Vec::with_capacity(col_count);
for column in columns.clone() {
match column {
ast::ResultColumn::Star => {
scalar_expressions.push(ProjectionColumn::Star);
}
ast::ResultColumn::TableStar(name) => {
let name_normalized = normalize_ident(name.0.as_str());
let referenced_table = referenced_tables
.iter()
.find(|(t, t_id)| *t_id == name_normalized);
if referenced_table.is_none() {
crate::bail_parse_error!("Table {} not found", name.0);
}
let (table, identifier) = referenced_table.unwrap();
scalar_expressions.push(ProjectionColumn::TableStar(
table.clone(),
identifier.clone(),
));
}
ast::ResultColumn::Expr(expr, _) => match expr {
ast::Expr::FunctionCall {
name,
distinctness,
args,
filter_over,
order_by,
} => {
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
match Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
args_count,
) {
Ok(Func::Agg(f)) => aggregate_expressions.push(Aggregate {
func: f,
args: args.unwrap(),
}),
Ok(_) => {
scalar_expressions.push(ProjectionColumn::Column(
ast::Expr::FunctionCall {
name,
distinctness,
args,
filter_over,
order_by,
},
));
}
_ => {}
}
}
ast::Expr::FunctionCallStar { name, filter_over } => {
match Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
0,
) {
Ok(Func::Agg(f)) => aggregate_expressions.push(Aggregate {
func: f,
args: vec![],
}),
Ok(Func::Scalar(_)) => {
scalar_expressions.push(ProjectionColumn::Column(
ast::Expr::FunctionCallStar { name, filter_over },
));
}
_ => {}
}
}
_ => {
scalar_expressions.push(ProjectionColumn::Column(expr));
}
},
}
}
let mixing_aggregate_and_non_aggregate_columns =
!aggregate_expressions.is_empty() && aggregate_expressions.len() != col_count;
if mixing_aggregate_and_non_aggregate_columns {
crate::bail_parse_error!(
"mixing aggregate and non-aggregate columns is not allowed (GROUP BY is not supported)"
);
}
if !aggregate_expressions.is_empty() {
node = Operator::Aggregate {
source: Box::new(node),
aggregates: aggregate_expressions,
id: node_id_counter.get_next_id(),
}
} else if !scalar_expressions.is_empty() {
node = Operator::Projection {
source: Box::new(node),
expressions: scalar_expressions,
id: node_id_counter.get_next_id(),
};
}
}
// Parse the ORDER BY clause
if let Some(order_by) = select.order_by {
let mut key = Vec::new();
for o in order_by {
// if the ORDER BY expression is a number, interpret it as a 1-indexed column number
// otherwise, interpret it normally as an expression
let expr = if let ast::Expr::Literal(ast::Literal::Numeric(num)) = o.expr {
let column_number = num.parse::<usize>()?;
if column_number == 0 {
crate::bail_parse_error!("invalid column index: {}", column_number);
}
let maybe_result_column = columns.get(column_number - 1);
match maybe_result_column {
Some(ResultColumn::Expr(expr, _)) => expr.clone(),
None => {
crate::bail_parse_error!("invalid column index: {}", column_number)
}
_ => todo!(),
}
} else {
o.expr
};
key.push((
expr,
o.order.map_or(Direction::Ascending, |o| match o {
ast::SortOrder::Asc => Direction::Ascending,
ast::SortOrder::Desc => Direction::Descending,
}),
));
}
node = Operator::Order {
source: Box::new(node),
key,
id: node_id_counter.get_next_id(),
};
}
// Parse the LIMIT clause
if let Some(limit) = &select.limit {
node = match &limit.expr {
ast::Expr::Literal(ast::Literal::Numeric(n)) => {
let l = n.parse()?;
if l == 0 {
Operator::Nothing
} else {
Operator::Limit {
source: Box::new(node),
limit: l,
id: node_id_counter.get_next_id(),
}
}
}
_ => todo!(),
}
}
// Return the unoptimized query plan
return Ok(Plan {
root_node: node,
referenced_tables,
});
}
_ => todo!(),
};
}
fn parse_from(
schema: &Schema,
from: Option<FromClause>,
node_id_counter: &mut NodeIdCounter,
) -> Result<(Operator, Vec<(Rc<BTreeTable>, String)>)> {
if from.as_ref().and_then(|f| f.select.as_ref()).is_none() {
return Ok((Operator::Nothing, vec![]));
}
let from = from.unwrap();
let first_table = match *from.select.unwrap() {
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
let Some(table) = schema.get_table(&qualified_name.name.0) else {
crate::bail_parse_error!("Table {} not found", qualified_name.name.0);
};
let alias = maybe_alias
.map(|a| match a {
ast::As::As(id) => id,
ast::As::Elided(id) => id,
})
.map(|a| a.0);
(table, alias.unwrap_or(qualified_name.name.0))
}
_ => todo!(),
};
let mut node = Operator::Scan {
table: first_table.0.clone(),
predicates: None,
table_identifier: first_table.1.clone(),
id: node_id_counter.get_next_id(),
};
let mut tables = vec![first_table];
for join in from.joins.unwrap_or_default().into_iter() {
let (right, outer, predicates) = parse_join(schema, join, node_id_counter, &mut tables)?;
node = Operator::Join {
left: Box::new(node),
right: Box::new(right),
predicates,
outer,
id: node_id_counter.get_next_id(),
}
}
return Ok((node, tables));
}
fn parse_join(
schema: &Schema,
join: ast::JoinedSelectTable,
node_id_counter: &mut NodeIdCounter,
tables: &mut Vec<(Rc<BTreeTable>, String)>,
) -> Result<(Operator, bool, Option<Vec<ast::Expr>>)> {
let ast::JoinedSelectTable {
operator,
table,
constraint,
} = join;
let table = match table {
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
let Some(table) = schema.get_table(&qualified_name.name.0) else {
crate::bail_parse_error!("Table {} not found", qualified_name.name.0);
};
let alias = maybe_alias
.map(|a| match a {
ast::As::As(id) => id,
ast::As::Elided(id) => id,
})
.map(|a| a.0);
(table, alias.unwrap_or(qualified_name.name.0))
}
_ => todo!(),
};
tables.push(table.clone());
let outer = match operator {
ast::JoinOperator::TypedJoin(Some(join_type)) => {
if join_type == JoinType::LEFT | JoinType::OUTER {
true
} else if join_type == JoinType::RIGHT | JoinType::OUTER {
true
} else {
false
}
}
_ => false,
};
let predicates = constraint.map(|c| match c {
ast::JoinConstraint::On(expr) => break_predicate_at_and_boundaries(expr, vec![]),
ast::JoinConstraint::Using(_) => todo!("USING joins not supported yet"),
});
Ok((
Operator::Scan {
table: table.0.clone(),
predicates: None,
table_identifier: table.1.clone(),
id: node_id_counter.get_next_id(),
},
outer,
predicates,
))
}
fn break_predicate_at_and_boundaries(
predicate: ast::Expr,
mut predicates: Vec<ast::Expr>,
) -> Vec<ast::Expr> {
match predicate {
ast::Expr::Binary(left, ast::Operator::And, right) => {
let ps = break_predicate_at_and_boundaries(*left, predicates);
let ps = break_predicate_at_and_boundaries(*right, ps);
ps
}
_ => {
predicates.push(predicate);
predicates
}
}
}
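For example, a WHERE clause like `a = 1 AND b > 2 AND c IS NULL` is broken up by break_predicate_at_and_boundaries into the list [a = 1, b > 2, c IS NULL], and the resulting Filter (or join) node carries them as separate predicates; this is what later lets the optimizer push each conjunct down independently.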

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1120,7 +1120,8 @@ impl Program {
let record = match *cursor.record()? {
Some(ref record) => record.clone(),
None => {
todo!();
state.pc += 1;
continue;
}
};
state.registers[*dest_reg] = OwnedValue::Record(record.clone());

View File

@@ -74,3 +74,7 @@ do_execsql_test select-string-agg-with-delimiter {
do_execsql_test select-string-agg-with-column-delimiter {
SELECT string_agg(name, id) FROM products;
} {hat2cap3shirt4sweater5sweatshirt6shorts7jeans8sneakers9boots10coat11accessories}
do_execsql_test select-count-star {
SELECT count(*) FROM users;
} {10000}

View File

@@ -141,7 +141,7 @@ do_execsql_test left-join-no-join-conditions-but-multiple-where {
} {Jamie|hat
Cindy|cap}
do_execsql_test left-join-order-by-qualified {
select users.first_name, products.name from users left join products on users.id = products.id where users.first_name like 'Jam%' order by null limit 2;
} {Jamie|hat
James|}
@@ -199,4 +199,4 @@ Jamie||Edward}
do_execsql_test left-join-constant-condition-true-inner-join-constant-condition-false {
select u.first_name, p.name, u2.first_name from users u left join products as p on 1 join users u2 on 0 limit 5;
} {}

View File

@@ -26,3 +26,11 @@ do_execsql_test select-add {
do_execsql_test case-insensitive-columns {
select u.aGe + 1 from USERS u where U.AGe = 91 limit 1;
} {92}
do_execsql_test table-star {
select p.*, p.name from products p limit 1;
} {1|hat|79.0|hat}
do_execsql_test table-star-2 {
select p.*, u.age from users u join products p limit 1;
} {1|hat|79.0|94}