select refactor: order by and basic agg kinda work

This commit is contained in:
jussisaurio
2024-11-23 12:33:41 +02:00
parent d0466e1cae
commit 3f9e60633f
10 changed files with 3032 additions and 2699 deletions

View File

@@ -235,8 +235,8 @@ impl Connection {
Cmd::ExplainQueryPlan(stmt) => {
match stmt {
ast::Stmt::Select(select) => {
let plan = prepare_select_plan(&self.schema.borrow(), select)?;
let (plan, _) = optimize_plan(plan)?;
let plan = prepare_select_plan(&*self.schema.borrow(), select)?;
let plan = optimize_plan(plan)?;
println!("{}", plan);
}
_ => todo!(),

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,5 @@
use sqlite3_parser::ast::{self, UnaryOperator};
use super::optimizer::CachedResult;
#[cfg(feature = "json")]
use crate::function::JsonFunc;
use crate::function::{AggFunc, Func, FuncCtx, ScalarFunc};
@@ -24,6 +23,7 @@ pub fn translate_condition_expr(
expr: &ast::Expr,
cursor_hint: Option<usize>,
condition_metadata: ConditionMetadata,
result_set_register_start: usize,
) -> Result<()> {
match expr {
ast::Expr::Between { .. } => todo!(),
@@ -39,6 +39,7 @@ pub fn translate_condition_expr(
jump_if_condition_is_true: false,
..condition_metadata
},
result_set_register_start,
);
let _ = translate_condition_expr(
program,
@@ -46,6 +47,7 @@ pub fn translate_condition_expr(
rhs,
cursor_hint,
condition_metadata,
result_set_register_start,
);
}
ast::Expr::Binary(lhs, ast::Operator::Or, rhs) => {
@@ -61,6 +63,7 @@ pub fn translate_condition_expr(
jump_target_when_false,
..condition_metadata
},
result_set_register_start,
);
program.resolve_label(jump_target_when_false, program.offset());
let _ = translate_condition_expr(
@@ -69,6 +72,7 @@ pub fn translate_condition_expr(
rhs,
cursor_hint,
condition_metadata,
result_set_register_start,
);
}
ast::Expr::Binary(lhs, op, rhs) => {
@@ -79,7 +83,7 @@ pub fn translate_condition_expr(
lhs,
lhs_reg,
cursor_hint,
None,
result_set_register_start,
);
if let ast::Expr::Literal(_) = lhs.as_ref() {
program.mark_last_insn_constant()
@@ -91,7 +95,7 @@ pub fn translate_condition_expr(
rhs,
rhs_reg,
cursor_hint,
None,
result_set_register_start,
);
if let ast::Expr::Literal(_) = rhs.as_ref() {
program.mark_last_insn_constant()
@@ -340,7 +344,7 @@ pub fn translate_condition_expr(
lhs,
lhs_reg,
cursor_hint,
None,
result_set_register_start,
)?;
let rhs = rhs.as_ref().unwrap();
@@ -370,7 +374,7 @@ pub fn translate_condition_expr(
expr,
rhs_reg,
cursor_hint,
None,
result_set_register_start,
)?;
// If this is not the last condition, we need to jump to the 'jump_target_when_true' label if the condition is true.
if !last_condition {
@@ -414,7 +418,7 @@ pub fn translate_condition_expr(
expr,
rhs_reg,
cursor_hint,
None,
result_set_register_start,
)?;
program.emit_insn_with_label_dependency(
Insn::Eq {
@@ -460,7 +464,7 @@ pub fn translate_condition_expr(
lhs,
column_reg,
cursor_hint,
None,
result_set_register_start,
)?;
if let ast::Expr::Literal(_) = lhs.as_ref() {
program.mark_last_insn_constant();
@@ -471,7 +475,7 @@ pub fn translate_condition_expr(
rhs,
pattern_reg,
cursor_hint,
None,
result_set_register_start,
)?;
if let ast::Expr::Literal(_) = rhs.as_ref() {
program.mark_last_insn_constant();
@@ -545,6 +549,7 @@ pub fn translate_condition_expr(
expr,
cursor_hint,
condition_metadata,
result_set_register_start,
);
}
}
@@ -553,71 +558,33 @@ pub fn translate_condition_expr(
Ok(())
}
pub fn get_cached_or_translate(
program: &mut ProgramBuilder,
referenced_tables: Option<&[BTreeTableReference]>,
expr: &ast::Expr,
cursor_hint: Option<usize>,
cached_results: Option<&Vec<&CachedResult>>,
) -> Result<usize> {
if let Some(cached_results) = cached_results {
if let Some(cached_result) = cached_results
.iter()
.find(|cached_result| cached_result.source_expr == *expr)
{
return Ok(cached_result.register_idx);
}
}
let reg = program.alloc_register();
translate_expr(
program,
referenced_tables,
expr,
reg,
cursor_hint,
cached_results,
)?;
Ok(reg)
}
pub fn translate_expr(
program: &mut ProgramBuilder,
referenced_tables: Option<&[BTreeTableReference]>,
expr: &ast::Expr,
target_register: usize,
cursor_hint: Option<usize>,
cached_results: Option<&Vec<&CachedResult>>,
result_set_register_start: usize,
) -> Result<usize> {
if let Some(cached_results) = &cached_results {
if let Some(cached_result) = cached_results
.iter()
.find(|cached_result| cached_result.source_expr == *expr)
{
program.emit_insn(Insn::Copy {
src_reg: cached_result.register_idx,
dst_reg: target_register,
amount: 0,
});
return Ok(target_register);
}
}
match expr {
ast::Expr::AggRef { index } => todo!(),
ast::Expr::Between { .. } => todo!(),
ast::Expr::Binary(e1, op, e2) => {
let e1_reg = get_cached_or_translate(
let e1_reg = translate_expr(
program,
referenced_tables,
e1,
target_register,
cursor_hint,
cached_results,
result_set_register_start,
)?;
let e2_reg = get_cached_or_translate(
let e2_reg = translate_expr(
program,
referenced_tables,
e2,
target_register,
cursor_hint,
cached_results,
result_set_register_start,
)?;
match op {
@@ -741,7 +708,7 @@ pub fn translate_expr(
expr,
reg_expr,
cursor_hint,
cached_results,
result_set_register_start,
)?;
let reg_type = program.alloc_register();
program.emit_insn(Insn::String8 {
@@ -814,7 +781,7 @@ pub fn translate_expr(
&args[0],
regs,
cursor_hint,
cached_results,
result_set_register_start,
)?;
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -841,7 +808,7 @@ pub fn translate_expr(
arg,
reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
}
@@ -879,7 +846,7 @@ pub fn translate_expr(
arg,
target_register,
cursor_hint,
cached_results,
result_set_register_start,
)?;
if index < args.len() - 1 {
program.emit_insn_with_label_dependency(
@@ -915,7 +882,7 @@ pub fn translate_expr(
arg,
reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
}
program.emit_insn(Insn::Function {
@@ -948,7 +915,7 @@ pub fn translate_expr(
arg,
reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
}
program.emit_insn(Insn::Function {
@@ -985,7 +952,7 @@ pub fn translate_expr(
&args[0],
temp_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
program.emit_insn(Insn::NotNull {
reg: temp_reg,
@@ -998,7 +965,7 @@ pub fn translate_expr(
&args[1],
temp_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
program.emit_insn(Insn::Copy {
src_reg: temp_reg,
@@ -1031,7 +998,7 @@ pub fn translate_expr(
arg,
reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
if let ast::Expr::Literal(_) = arg {
program.mark_last_insn_constant()
@@ -1079,7 +1046,7 @@ pub fn translate_expr(
&args[0],
regs,
cursor_hint,
cached_results,
result_set_register_start,
)?;
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -1116,7 +1083,7 @@ pub fn translate_expr(
arg,
target_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
}
}
@@ -1154,7 +1121,7 @@ pub fn translate_expr(
&args[0],
str_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
translate_expr(
program,
@@ -1162,7 +1129,7 @@ pub fn translate_expr(
&args[1],
start_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
if args.len() == 3 {
translate_expr(
@@ -1171,7 +1138,7 @@ pub fn translate_expr(
&args[2],
length_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
}
@@ -1201,7 +1168,7 @@ pub fn translate_expr(
&args[0],
regs,
cursor_hint,
cached_results,
result_set_register_start,
)?;
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -1225,7 +1192,7 @@ pub fn translate_expr(
&args[0],
arg_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
start_reg = arg_reg;
}
@@ -1250,7 +1217,7 @@ pub fn translate_expr(
arg,
target_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
}
}
@@ -1290,7 +1257,7 @@ pub fn translate_expr(
arg,
reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
if let ast::Expr::Literal(_) = arg {
program.mark_last_insn_constant();
@@ -1323,7 +1290,7 @@ pub fn translate_expr(
arg,
reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
if let ast::Expr::Literal(_) = arg {
program.mark_last_insn_constant()
@@ -1357,7 +1324,7 @@ pub fn translate_expr(
arg,
reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
if let ast::Expr::Literal(_) = arg {
program.mark_last_insn_constant()
@@ -1395,7 +1362,7 @@ pub fn translate_expr(
&args[0],
first_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
let second_reg = program.alloc_register();
translate_expr(
@@ -1404,7 +1371,7 @@ pub fn translate_expr(
&args[1],
second_reg,
cursor_hint,
cached_results,
result_set_register_start,
)?;
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -1536,7 +1503,7 @@ pub fn translate_expr(
&exprs[0],
target_register,
cursor_hint,
cached_results,
result_set_register_start,
)?;
} else {
// Parenthesized expressions with multiple arguments are reserved for special cases
@@ -1660,7 +1627,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
None,
0,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -1682,7 +1649,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
None,
0,
);
expr_reg
};
@@ -1725,7 +1692,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
None,
0,
)?;
translate_expr(
program,
@@ -1733,7 +1700,7 @@ pub fn translate_aggregation(
&delimiter_expr,
delimiter_reg,
cursor_hint,
None,
0,
)?;
program.emit_insn(Insn::AggStep {
@@ -1757,7 +1724,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
None,
0,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -1779,7 +1746,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
None,
0,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -1816,7 +1783,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
None,
0,
)?;
translate_expr(
program,
@@ -1824,7 +1791,7 @@ pub fn translate_aggregation(
&delimiter_expr,
delimiter_reg,
cursor_hint,
None,
0,
)?;
program.emit_insn(Insn::AggStep {
@@ -1848,7 +1815,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
None,
0,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -1870,7 +1837,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
None,
0,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,

View File

@@ -98,7 +98,7 @@ pub fn translate_insert(
expr,
column_registers_start + col,
None,
None,
0,
)?;
}
program.emit_insn(Insn::Yield {

View File

@@ -1,4 +1,4 @@
use std::{collections::HashMap, rc::Rc};
use std::rc::Rc;
use sqlite3_parser::ast;
@@ -6,7 +6,7 @@ use crate::{schema::Index, Result};
use super::plan::{
get_table_ref_bitmask_for_ast_expr, get_table_ref_bitmask_for_operator, BTreeTableReference,
Direction, IterationDirection, Operator, Plan, ProjectionColumn, Search,
Direction, IterationDirection, Plan, Search, SourceOperator,
};
/**
@@ -14,49 +14,45 @@ use super::plan::{
* TODO: these could probably be done in less passes,
* but having them separate makes them easier to understand
*/
pub fn optimize_plan(mut select_plan: Plan) -> Result<(Plan, ExpressionResultCache)> {
let mut expr_result_cache = ExpressionResultCache::new();
pub fn optimize_plan(mut select_plan: Plan) -> Result<Plan> {
push_predicates(
&mut select_plan.root_operator,
&mut select_plan.source,
&mut select_plan.where_clause,
&select_plan.referenced_tables,
)?;
if eliminate_constants(&mut select_plan.root_operator)?
if eliminate_constants(&mut select_plan.source)?
== ConstantConditionEliminationResult::ImpossibleCondition
{
return Ok((
Plan {
root_operator: Operator::Nothing,
referenced_tables: vec![],
available_indexes: vec![],
},
expr_result_cache,
));
return Ok(Plan {
source: SourceOperator::Nothing,
..select_plan
});
}
use_indexes(
&mut select_plan.root_operator,
&mut select_plan.source,
&select_plan.referenced_tables,
&select_plan.available_indexes,
)?;
eliminate_unnecessary_orderby(
&mut select_plan.root_operator,
&mut select_plan.source,
&mut select_plan.order_by,
&select_plan.referenced_tables,
&select_plan.available_indexes,
)?;
find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(&select_plan.root_operator, &mut expr_result_cache);
Ok((select_plan, expr_result_cache))
Ok(select_plan)
}
fn _operator_is_already_ordered_by(
operator: &mut Operator,
operator: &mut SourceOperator,
key: &mut ast::Expr,
referenced_tables: &[BTreeTableReference],
available_indexes: &Vec<Rc<Index>>,
) -> Result<bool> {
match operator {
Operator::Scan {
SourceOperator::Scan {
table_reference, ..
} => Ok(key.is_primary_key_of(table_reference.table_index)),
Operator::Search {
SourceOperator::Search {
table_reference,
search,
..
@@ -77,61 +73,53 @@ fn _operator_is_already_ordered_by(
Ok(index_is_the_same)
}
},
Operator::Join { left, .. } => {
SourceOperator::Join { left, .. } => {
_operator_is_already_ordered_by(left, key, referenced_tables, available_indexes)
}
Operator::Aggregate { source, .. } => {
_operator_is_already_ordered_by(source, key, referenced_tables, available_indexes)
}
Operator::Projection { source, .. } => {
_operator_is_already_ordered_by(source, key, referenced_tables, available_indexes)
}
_ => Ok(false),
}
}
fn eliminate_unnecessary_orderby(
operator: &mut Operator,
operator: &mut SourceOperator,
order_by: &mut Option<Vec<(ast::Expr, Direction)>>,
referenced_tables: &[BTreeTableReference],
available_indexes: &Vec<Rc<Index>>,
) -> Result<()> {
match operator {
Operator::Order { source, key, .. } => {
if key.len() != 1 {
// TODO: handle multiple order by keys
return Ok(());
}
let (key, direction) = key.first_mut().unwrap();
let already_ordered = _operator_is_already_ordered_by(source, key, referenced_tables, available_indexes)?;
if already_ordered {
push_scan_direction(source, direction);
*operator = source.take_ownership();
}
Ok(())
}
Operator::Limit { source, .. } => {
eliminate_unnecessary_orderby(source, referenced_tables, available_indexes)?;
Ok(())
}
_ => Ok(()),
if order_by.is_none() {
return Ok(());
}
let o = order_by.as_mut().unwrap();
if o.len() != 1 {
// TODO: handle multiple order by keys
return Ok(());
}
let (key, _) = o.first_mut().unwrap();
let already_ordered =
_operator_is_already_ordered_by(operator, key, referenced_tables, available_indexes)?;
if already_ordered {
*order_by = None;
}
Ok(())
}
/**
* Use indexes where possible
*/
fn use_indexes(
operator: &mut Operator,
operator: &mut SourceOperator,
referenced_tables: &[BTreeTableReference],
available_indexes: &[Rc<Index>],
) -> Result<()> {
match operator {
Operator::Search { .. } => Ok(()),
Operator::Scan {
SourceOperator::Search { .. } => Ok(()),
SourceOperator::Scan {
table_reference,
predicates: filter,
id,
@@ -162,12 +150,11 @@ fn use_indexes(
}
Either::Right(index_search) => {
fs.remove(i);
*operator = Operator::Search {
*operator = SourceOperator::Search {
id: *id,
table_reference: table_reference.clone(),
predicates: Some(fs.clone()),
search: index_search,
step: 0,
};
return Ok(());
@@ -177,32 +164,12 @@ fn use_indexes(
Ok(())
}
Operator::Aggregate { source, .. } => {
use_indexes(source, referenced_tables, available_indexes)?;
Ok(())
}
Operator::Filter { source, .. } => {
use_indexes(source, referenced_tables, available_indexes)?;
Ok(())
}
Operator::Limit { source, .. } => {
use_indexes(source, referenced_tables, available_indexes)?;
Ok(())
}
Operator::Join { left, right, .. } => {
SourceOperator::Join { left, right, .. } => {
use_indexes(left, referenced_tables, available_indexes)?;
use_indexes(right, referenced_tables, available_indexes)?;
Ok(())
}
Operator::Order { source, .. } => {
use_indexes(source, referenced_tables, available_indexes)?;
Ok(())
}
Operator::Projection { source, .. } => {
use_indexes(source, referenced_tables, available_indexes)?;
Ok(())
}
Operator::Nothing => Ok(()),
SourceOperator::Nothing => Ok(()),
}
}
@@ -214,33 +181,11 @@ enum ConstantConditionEliminationResult {
// removes predicates that are always true
// returns a ConstantEliminationResult indicating whether any predicates are always false
fn eliminate_constants(operator: &mut Operator) -> Result<ConstantConditionEliminationResult> {
fn eliminate_constants(
operator: &mut SourceOperator,
) -> Result<ConstantConditionEliminationResult> {
match operator {
Operator::Filter {
source, predicates, ..
} => {
let mut i = 0;
while i < predicates.len() {
let predicate = &predicates[i];
if predicate.is_always_true()? {
predicates.remove(i);
} else if predicate.is_always_false()? {
return Ok(ConstantConditionEliminationResult::ImpossibleCondition);
} else {
i += 1;
}
}
if predicates.is_empty() {
*operator = source.take_ownership();
eliminate_constants(operator)?;
} else {
eliminate_constants(source)?;
}
Ok(ConstantConditionEliminationResult::Continue)
}
Operator::Join {
SourceOperator::Join {
left,
right,
predicates,
@@ -278,44 +223,7 @@ fn eliminate_constants(operator: &mut Operator) -> Result<ConstantConditionElimi
Ok(ConstantConditionEliminationResult::Continue)
}
Operator::Aggregate { source, .. } => {
if eliminate_constants(source)?
== ConstantConditionEliminationResult::ImpossibleCondition
{
*source = Box::new(Operator::Nothing);
}
// Aggregation operator can return a row even if the source is empty e.g. count(1) from users where 0
Ok(ConstantConditionEliminationResult::Continue)
}
Operator::Limit { source, .. } => {
let constant_elimination_result = eliminate_constants(source)?;
if constant_elimination_result
== ConstantConditionEliminationResult::ImpossibleCondition
{
*operator = Operator::Nothing;
}
Ok(constant_elimination_result)
}
Operator::Order { source, .. } => {
if eliminate_constants(source)?
== ConstantConditionEliminationResult::ImpossibleCondition
{
*operator = Operator::Nothing;
return Ok(ConstantConditionEliminationResult::ImpossibleCondition);
}
Ok(ConstantConditionEliminationResult::Continue)
}
Operator::Projection { source, .. } => {
if eliminate_constants(source)?
== ConstantConditionEliminationResult::ImpossibleCondition
{
*operator = Operator::Nothing;
return Ok(ConstantConditionEliminationResult::ImpossibleCondition);
}
Ok(ConstantConditionEliminationResult::Continue)
}
Operator::Scan { predicates, .. } => {
SourceOperator::Scan { predicates, .. } => {
if let Some(ps) = predicates {
let mut i = 0;
while i < ps.len() {
@@ -335,7 +243,7 @@ fn eliminate_constants(operator: &mut Operator) -> Result<ConstantConditionElimi
}
Ok(ConstantConditionEliminationResult::Continue)
}
Operator::Search { predicates, .. } => {
SourceOperator::Search { predicates, .. } => {
if let Some(predicates) = predicates {
let mut i = 0;
while i < predicates.len() {
@@ -352,7 +260,7 @@ fn eliminate_constants(operator: &mut Operator) -> Result<ConstantConditionElimi
Ok(ConstantConditionEliminationResult::Continue)
}
Operator::Nothing => Ok(ConstantConditionEliminationResult::Continue),
SourceOperator::Nothing => Ok(ConstantConditionEliminationResult::Continue),
}
}
@@ -360,42 +268,35 @@ fn eliminate_constants(operator: &mut Operator) -> Result<ConstantConditionElimi
Recursively pushes predicates down the tree, as far as possible.
*/
fn push_predicates(
operator: &mut Operator,
operator: &mut SourceOperator,
where_clause: &mut Option<Vec<ast::Expr>>,
referenced_tables: &Vec<BTreeTableReference>,
) -> Result<()> {
match operator {
Operator::Filter {
source, predicates, ..
} => {
let mut i = 0;
while i < predicates.len() {
// try to push the predicate to the source
// if it succeeds, remove the predicate from the filter
let predicate_owned = predicates[i].take_ownership();
let Some(predicate) = push_predicate(source, predicate_owned, referenced_tables)?
else {
predicates.remove(i);
continue;
};
predicates[i] = predicate;
i += 1;
}
if predicates.is_empty() {
*operator = source.take_ownership();
}
Ok(())
if let Some(predicates) = where_clause {
let mut i = 0;
while i < predicates.len() {
let predicate = predicates[i].take_ownership();
let Some(predicate) = push_predicate(operator, predicate, referenced_tables)? else {
predicates.remove(i);
continue;
};
predicates[i] = predicate;
i += 1;
}
Operator::Join {
if predicates.is_empty() {
*where_clause = None;
}
}
match operator {
SourceOperator::Join {
left,
right,
predicates,
outer,
..
} => {
push_predicates(left, referenced_tables)?;
push_predicates(right, referenced_tables)?;
push_predicates(left, where_clause, referenced_tables)?;
push_predicates(right, where_clause, referenced_tables)?;
if predicates.is_none() {
return Ok(());
@@ -433,26 +334,9 @@ fn push_predicates(
Ok(())
}
Operator::Aggregate { source, .. } => {
push_predicates(source, referenced_tables)?;
Ok(())
}
Operator::Limit { source, .. } => {
push_predicates(source, referenced_tables)?;
Ok(())
}
Operator::Order { source, .. } => {
push_predicates(source, referenced_tables)?;
Ok(())
}
Operator::Projection { source, .. } => {
push_predicates(source, referenced_tables)?;
Ok(())
}
Operator::Scan { .. } => Ok(()),
Operator::Search { .. } => Ok(()),
Operator::Nothing => Ok(()),
SourceOperator::Scan { .. } => Ok(()),
SourceOperator::Search { .. } => Ok(()),
SourceOperator::Nothing => Ok(()),
}
}
@@ -461,12 +345,12 @@ fn push_predicates(
Returns Ok(None) if the predicate was pushed, otherwise returns itself as Ok(Some(predicate))
*/
fn push_predicate(
operator: &mut Operator,
operator: &mut SourceOperator,
predicate: ast::Expr,
referenced_tables: &Vec<BTreeTableReference>,
) -> Result<Option<ast::Expr>> {
match operator {
Operator::Scan {
SourceOperator::Scan {
predicates,
table_reference,
..
@@ -497,22 +381,8 @@ fn push_predicate(
Ok(None)
}
Operator::Search { .. } => Ok(Some(predicate)),
Operator::Filter {
source,
predicates: ps,
..
} => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
ps.push(push_result.unwrap());
Ok(None)
}
Operator::Join {
SourceOperator::Search { .. } => Ok(Some(predicate)),
SourceOperator::Join {
left,
right,
predicates: join_on_preds,
@@ -552,46 +422,13 @@ fn push_predicate(
Ok(None)
}
Operator::Aggregate { source, .. } => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
Ok(Some(push_result.unwrap()))
}
Operator::Limit { source, .. } => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
Ok(Some(push_result.unwrap()))
}
Operator::Order { source, .. } => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
Ok(Some(push_result.unwrap()))
}
Operator::Projection { source, .. } => {
let push_result = push_predicate(source, predicate, referenced_tables)?;
if push_result.is_none() {
return Ok(None);
}
Ok(Some(push_result.unwrap()))
}
Operator::Nothing => Ok(Some(predicate)),
SourceOperator::Nothing => Ok(Some(predicate)),
}
}
fn push_scan_direction(operator: &mut Operator, direction: &Direction) {
fn push_scan_direction(operator: &mut SourceOperator, direction: &Direction) {
match operator {
Operator::Projection { source, .. } => push_scan_direction(source, direction),
Operator::Scan { iter_dir, .. } => {
SourceOperator::Scan { iter_dir, .. } => {
if iter_dir.is_none() {
match direction {
Direction::Ascending => *iter_dir = Some(IterationDirection::Forwards),
@@ -603,381 +440,6 @@ fn push_scan_direction(operator: &mut Operator, direction: &Direction) {
}
}
#[derive(Debug)]
pub struct ExpressionResultCache {
resultmap: HashMap<usize, CachedResult>,
keymap: HashMap<usize, Vec<usize>>,
}
#[derive(Debug)]
pub struct CachedResult {
pub register_idx: usize,
pub source_expr: ast::Expr,
}
const OPERATOR_ID_MULTIPLIER: usize = 10000;
/**
ExpressionResultCache is a cache for the results of expressions that are computed in the query plan,
or more precisely, the VM registers that hold the results of these expressions.
Right now the cache is mainly used to avoid recomputing e.g. the result of an aggregation expression
e.g. SELECT t.a, SUM(t.b) FROM t GROUP BY t.a ORDER BY SUM(t.b)
*/
impl ExpressionResultCache {
pub fn new() -> Self {
ExpressionResultCache {
resultmap: HashMap::new(),
keymap: HashMap::new(),
}
}
/**
Store the result of an expression that is computed in the query plan.
The result is stored in a VM register. A copy of the expression AST node is
stored as well, so that parent operators can use it to compare their own expressions
with the one that was computed in a child operator.
This is a weakness of our current reliance on a 3rd party AST library, as we can't
e.g. modify the AST to add identifiers to nodes or replace nodes with some kind of
reference to a register, etc.
*/
pub fn cache_result_register(
&mut self,
operator_id: usize,
result_column_idx: usize,
register_idx: usize,
expr: ast::Expr,
) {
let key = operator_id * OPERATOR_ID_MULTIPLIER + result_column_idx;
self.resultmap.insert(
key,
CachedResult {
register_idx,
source_expr: expr,
},
);
}
/**
Set a mapping from a parent operator to a child operator, so that the parent operator
can look up the register of a result that was computed in the child operator.
E.g. "Parent operator's result column 3 is computed in child operator 5, result column 2"
*/
pub fn set_precomputation_key(
&mut self,
operator_id: usize,
result_column_idx: usize,
child_operator_id: usize,
child_operator_result_column_idx_mask: usize,
) {
let key = operator_id * OPERATOR_ID_MULTIPLIER + result_column_idx;
let mut values = Vec::new();
for i in 0..64 {
if (child_operator_result_column_idx_mask >> i) & 1 == 1 {
values.push(child_operator_id * OPERATOR_ID_MULTIPLIER + i);
}
}
self.keymap.insert(key, values);
}
/**
Get the cache entries for a given operator and result column index.
There may be multiple cached entries, e.g. a binary operator's both
arms may have been cached.
*/
pub fn get_cached_result_registers(
&self,
operator_id: usize,
result_column_idx: usize,
) -> Option<Vec<&CachedResult>> {
let key = operator_id * OPERATOR_ID_MULTIPLIER + result_column_idx;
self.keymap.get(&key).and_then(|keys| {
let mut results = Vec::new();
for key in keys {
if let Some(result) = self.resultmap.get(key) {
results.push(result);
}
}
if results.is_empty() {
None
} else {
Some(results)
}
})
}
}
type ResultColumnIndexBitmask = usize;
/**
Find all result columns in an operator that match an expression, either fully or partially.
This is used to find the result columns that are computed in an operator and that are used
in a parent operator, so that the parent operator can look up the register that holds the result
of the child operator's expression.
The result is returned as a bitmask due to performance neuroticism. A limitation of this is that
we can only handle 64 result columns per operator.
*/
fn find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(
expr: &ast::Expr,
operator: &Operator,
) -> ResultColumnIndexBitmask {
let exact_match = match operator {
Operator::Aggregate {
aggregates,
group_by,
..
} => {
let mut idx = 0;
let mut mask = 0;
for agg in aggregates.iter() {
if agg.original_expr == *expr {
mask |= 1 << idx;
}
idx += 1;
}
if let Some(group_by) = group_by {
for g in group_by.iter() {
if g == expr {
mask |= 1 << idx;
}
idx += 1
}
}
mask
}
Operator::Filter { .. } => 0,
Operator::Limit { .. } => 0,
Operator::Join { .. } => 0,
Operator::Order { .. } => 0,
Operator::Projection { expressions, .. } => {
let mut mask = 0;
for (idx, e) in expressions.iter().enumerate() {
match e {
ProjectionColumn::Column(c) => {
if c == expr {
mask |= 1 << idx;
}
}
ProjectionColumn::Star => {}
ProjectionColumn::TableStar(_) => {}
}
}
mask
}
Operator::Scan { .. } => 0,
Operator::Search { .. } => 0,
Operator::Nothing => 0,
};
if exact_match != 0 {
return exact_match;
}
match expr {
ast::Expr::Between {
lhs,
not: _,
start,
end,
} => {
let mut mask = 0;
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(lhs, operator);
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(start, operator);
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(end, operator);
mask
}
ast::Expr::Binary(lhs, _op, rhs) => {
let mut mask = 0;
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(lhs, operator);
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(rhs, operator);
mask
}
ast::Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
let mut mask = 0;
if let Some(base) = base {
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(base, operator);
}
for (w, t) in when_then_pairs.iter() {
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(w, operator);
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(t, operator);
}
if let Some(e) = else_expr {
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(e, operator);
}
mask
}
ast::Expr::Cast { expr, type_name: _ } => {
find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(
expr, operator,
)
}
ast::Expr::Collate(expr, _collation) => {
find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(
expr, operator,
)
}
ast::Expr::DoublyQualified(_schema, _tbl, _ident) => 0,
ast::Expr::Exists(_) => 0,
ast::Expr::FunctionCall {
name: _,
distinctness: _,
args,
order_by: _,
filter_over: _,
} => {
let mut mask = 0;
if let Some(args) = args {
for a in args.iter() {
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(a, operator);
}
}
mask
}
ast::Expr::FunctionCallStar {
name: _,
filter_over: _,
} => 0,
ast::Expr::Id(_) => unreachable!("Ids have been bound to Column references"),
ast::Expr::Column { .. } => 0,
ast::Expr::InList { lhs, not: _, rhs } => {
let mut mask = 0;
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(lhs, operator);
if let Some(rhs) = rhs {
for r in rhs.iter() {
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(r, operator);
}
}
mask
}
ast::Expr::InSelect {
lhs,
not: _,
rhs: _,
} => {
find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(
lhs, operator,
)
}
ast::Expr::InTable {
lhs: _,
not: _,
rhs: _,
args: _,
} => 0,
ast::Expr::IsNull(expr) => {
find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(
expr, operator,
)
}
ast::Expr::Like {
lhs,
not: _,
op: _,
rhs,
escape: _,
} => {
let mut mask = 0;
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(lhs, operator);
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(rhs, operator);
mask
}
ast::Expr::Literal(_) => 0,
ast::Expr::Name(_) => 0,
ast::Expr::NotNull(expr) => {
find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(
expr, operator,
)
}
ast::Expr::Parenthesized(expr) => {
let mut mask = 0;
for e in expr.iter() {
mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(e, operator);
}
mask
}
ast::Expr::Qualified(_, _) => {
unreachable!("Qualified expressions have been bound to Column references")
}
ast::Expr::Raise(_, _) => 0,
ast::Expr::Subquery(_) => 0,
ast::Expr::Unary(_op, expr) => {
find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(
expr, operator,
)
}
ast::Expr::Variable(_) => 0,
}
}
/**
* This function is used to find all the expressions that are shared between the parent operator and the child operators.
* If an expression is shared between the parent and child operators, then the parent operator should not recompute the expression.
* Instead, it should use the result of the expression that was computed by the child operator.
*/
fn find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(
operator: &Operator,
expr_result_cache: &mut ExpressionResultCache,
) {
match operator {
Operator::Aggregate {
source,
..
} => {
find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(
source, expr_result_cache,
)
}
Operator::Filter { .. } => unreachable!(),
Operator::Limit { source, .. } => {
find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(source, expr_result_cache)
}
Operator::Join { .. } => {}
Operator::Order { source, key, .. } => {
for (idx, (expr, _)) in key.iter().enumerate() {
let result = find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(expr, source);
if result != 0 {
expr_result_cache.set_precomputation_key(
operator.id(),
idx,
source.id(),
result,
);
}
}
find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(source, expr_result_cache)
}
Operator::Projection { source, expressions, .. } => {
for (idx, expr) in expressions.iter().enumerate() {
if let ProjectionColumn::Column(expr) = expr {
let result = find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(expr, source);
if result != 0 {
expr_result_cache.set_precomputation_key(
operator.id(),
idx,
source.id(),
result,
);
}
}
}
find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(source, expr_result_cache)
}
Operator::Scan { .. } => {}
Operator::Search { .. } => {}
Operator::Nothing => {}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstantPredicate {
AlwaysTrue,
@@ -1286,8 +748,8 @@ impl TakeOwnership for ast::Expr {
}
}
impl TakeOwnership for Operator {
impl TakeOwnership for SourceOperator {
fn take_ownership(&mut self) -> Self {
std::mem::replace(self, Operator::Nothing)
std::mem::replace(self, SourceOperator::Nothing)
}
}

View File

@@ -12,16 +12,29 @@ use crate::{
Result,
};
#[derive(Debug)]
pub enum ResultSetColumn {
Scalar(ast::Expr),
Agg(Aggregate),
ComputedAgg(ast::Expr),
}
#[derive(Debug)]
pub struct Plan {
pub root_operator: Operator,
pub source: SourceOperator,
pub result_columns: Vec<ResultSetColumn>,
pub where_clause: Option<Vec<ast::Expr>>,
pub group_by: Option<Vec<ast::Expr>>,
pub order_by: Option<Vec<(ast::Expr, Direction)>>,
pub aggregates: Option<Vec<Aggregate>>,
pub limit: Option<usize>,
pub referenced_tables: Vec<BTreeTableReference>,
pub available_indexes: Vec<Rc<Index>>,
}
impl Display for Plan {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.root_operator)
write!(f, "{}", self.source)
}
}
@@ -45,69 +58,17 @@ pub enum IterationDirection {
TODO: perhaps 'step' shouldn't be in this struct, since it's an execution time concept, not a plan time concept.
*/
#[derive(Clone, Debug)]
pub enum Operator {
// Aggregate operator
// This operator is used to compute aggregate functions like SUM, AVG, COUNT, etc.
// It takes a source operator and a list of aggregate functions to compute.
// GROUP BY is not supported yet.
Aggregate {
id: usize,
source: Box<Operator>,
aggregates: Vec<Aggregate>,
group_by: Option<Vec<ast::Expr>>,
step: usize,
},
// Filter operator
// This operator is used to filter rows from the source operator.
// It takes a source operator and a list of predicates to evaluate.
// Only rows for which all predicates evaluate to true are passed to the next operator.
// Generally filter operators will only exist in unoptimized plans,
// as the optimizer will try to push filters down to the lowest possible level,
// e.g. a table scan.
Filter {
id: usize,
source: Box<Operator>,
predicates: Vec<ast::Expr>,
},
// Limit operator
// This operator is used to limit the number of rows returned by the source operator.
Limit {
id: usize,
source: Box<Operator>,
limit: usize,
step: usize,
},
pub enum SourceOperator {
// Join operator
// This operator is used to join two source operators.
// It takes a left and right source operator, a list of predicates to evaluate,
// and a boolean indicating whether it is an outer join.
Join {
id: usize,
left: Box<Operator>,
right: Box<Operator>,
left: Box<SourceOperator>,
right: Box<SourceOperator>,
predicates: Option<Vec<ast::Expr>>,
outer: bool,
step: usize,
},
// Order operator
// This operator is used to sort the rows returned by the source operator.
Order {
id: usize,
source: Box<Operator>,
key: Vec<(ast::Expr, Direction)>,
step: usize,
},
// Projection operator
// This operator is used to project columns from the source operator.
// It takes a source operator and a list of expressions to evaluate.
// e.g. SELECT foo, bar FROM t1
// In this example, the expressions would be [foo, bar]
// and the source operator would be a Scan operator for table t1.
Projection {
id: usize,
source: Box<Operator>,
expressions: Vec<ProjectionColumn>,
step: usize,
},
// Scan operator
// This operator is used to scan a table.
@@ -122,7 +83,6 @@ pub enum Operator {
id: usize,
table_reference: BTreeTableReference,
predicates: Option<Vec<ast::Expr>>,
step: usize,
iter_dir: Option<IterationDirection>,
},
// Search operator
@@ -133,7 +93,6 @@ pub enum Operator {
table_reference: BTreeTableReference,
search: Search,
predicates: Option<Vec<ast::Expr>>,
step: usize,
},
// Nothing operator
// This operator is used to represent an empty query.
@@ -168,106 +127,30 @@ pub enum Search {
},
}
#[derive(Clone, Debug)]
pub enum ProjectionColumn {
Column(ast::Expr),
Star,
TableStar(BTreeTableReference),
}
impl ProjectionColumn {
impl SourceOperator {
pub fn column_count(&self, referenced_tables: &[BTreeTableReference]) -> usize {
match self {
ProjectionColumn::Column(_) => 1,
ProjectionColumn::Star => {
let mut count = 0;
for table_reference in referenced_tables {
count += table_reference.table.columns.len();
}
count
}
ProjectionColumn::TableStar(table_reference) => table_reference.table.columns.len(),
}
}
}
impl Operator {
pub fn column_count(&self, referenced_tables: &[BTreeTableReference]) -> usize {
match self {
Operator::Aggregate {
group_by,
aggregates,
..
} => aggregates.len() + group_by.as_ref().map_or(0, |g| g.len()),
Operator::Filter { source, .. } => source.column_count(referenced_tables),
Operator::Limit { source, .. } => source.column_count(referenced_tables),
Operator::Join { left, right, .. } => {
SourceOperator::Join { left, right, .. } => {
left.column_count(referenced_tables) + right.column_count(referenced_tables)
}
Operator::Order { source, .. } => source.column_count(referenced_tables),
Operator::Projection { expressions, .. } => expressions
.iter()
.map(|e| e.column_count(referenced_tables))
.sum(),
Operator::Scan {
SourceOperator::Scan {
table_reference, ..
} => table_reference.table.columns.len(),
Operator::Search {
SourceOperator::Search {
table_reference, ..
} => table_reference.table.columns.len(),
Operator::Nothing => 0,
SourceOperator::Nothing => 0,
}
}
pub fn column_names(&self) -> Vec<String> {
match self {
Operator::Aggregate {
aggregates,
group_by,
..
} => {
let mut names = vec![];
for agg in aggregates.iter() {
names.push(agg.func.to_string().to_string());
}
if let Some(group_by) = group_by {
for expr in group_by.iter() {
match expr {
ast::Expr::Id(ident) => names.push(ident.0.clone()),
ast::Expr::Qualified(tbl, ident) => {
names.push(format!("{}.{}", tbl.0, ident.0))
}
e => names.push(e.to_string()),
}
}
}
names
}
Operator::Filter { source, .. } => source.column_names(),
Operator::Limit { source, .. } => source.column_names(),
Operator::Join { left, right, .. } => {
SourceOperator::Join { left, right, .. } => {
let mut names = left.column_names();
names.extend(right.column_names());
names
}
Operator::Order { source, .. } => source.column_names(),
Operator::Projection { expressions, .. } => expressions
.iter()
.map(|e| match e {
ProjectionColumn::Column(expr) => match expr {
ast::Expr::Id(ident) => ident.0.clone(),
ast::Expr::Qualified(tbl, ident) => format!("{}.{}", tbl.0, ident.0),
_ => "expr".to_string(),
},
ProjectionColumn::Star => "*".to_string(),
ProjectionColumn::TableStar(table_reference) => {
format!("{}.{}", table_reference.table_identifier, "*")
}
})
.collect(),
Operator::Scan {
SourceOperator::Scan {
table_reference, ..
} => table_reference
.table
@@ -275,7 +158,7 @@ impl Operator {
.iter()
.map(|c| c.name.clone())
.collect(),
Operator::Search {
SourceOperator::Search {
table_reference, ..
} => table_reference
.table
@@ -283,21 +166,16 @@ impl Operator {
.iter()
.map(|c| c.name.clone())
.collect(),
Operator::Nothing => vec![],
SourceOperator::Nothing => vec![],
}
}
pub fn id(&self) -> usize {
match self {
Operator::Aggregate { id, .. } => *id,
Operator::Filter { id, .. } => *id,
Operator::Limit { id, .. } => *id,
Operator::Join { id, .. } => *id,
Operator::Order { id, .. } => *id,
Operator::Projection { id, .. } => *id,
Operator::Scan { id, .. } => *id,
Operator::Search { id, .. } => *id,
Operator::Nothing => unreachable!(),
SourceOperator::Join { id, .. } => *id,
SourceOperator::Scan { id, .. } => *id,
SourceOperator::Search { id, .. } => *id,
SourceOperator::Nothing => unreachable!(),
}
}
}
@@ -337,10 +215,10 @@ impl Display for Aggregate {
}
// For EXPLAIN QUERY PLAN
impl Display for Operator {
impl Display for SourceOperator {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
fn fmt_operator(
operator: &Operator,
operator: &SourceOperator,
f: &mut Formatter,
level: usize,
last: bool,
@@ -356,34 +234,7 @@ impl Display for Operator {
};
match operator {
Operator::Aggregate {
source, aggregates, ..
} => {
// e.g. Aggregate count(*), sum(x)
let aggregates_display_string = aggregates
.iter()
.map(|agg| agg.to_string())
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}AGGREGATE {}", indent, aggregates_display_string)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Filter {
source, predicates, ..
} => {
let predicates_string = predicates
.iter()
.map(|p| p.to_string())
.collect::<Vec<String>>()
.join(" AND ");
writeln!(f, "{}FILTER {}", indent, predicates_string)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Limit { source, limit, .. } => {
writeln!(f, "{}TAKE {}", indent, limit)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Join {
SourceOperator::Join {
left,
right,
predicates,
@@ -408,35 +259,7 @@ impl Display for Operator {
fmt_operator(left, f, level + 1, false)?;
fmt_operator(right, f, level + 1, true)
}
Operator::Order { source, key, .. } => {
let sort_keys_string = key
.iter()
.map(|(expr, dir)| format!("{} {}", expr, dir))
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}SORT {}", indent, sort_keys_string)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Projection {
source,
expressions,
..
} => {
let expressions = expressions
.iter()
.map(|expr| match expr {
ProjectionColumn::Column(c) => c.to_string(),
ProjectionColumn::Star => "*".to_string(),
ProjectionColumn::TableStar(table_reference) => {
format!("{}.{}", table_reference.table_identifier, "*")
}
})
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}PROJECT {}", indent, expressions)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Scan {
SourceOperator::Scan {
table_reference,
predicates: filter,
..
@@ -464,7 +287,7 @@ impl Display for Operator {
}?;
Ok(())
}
Operator::Search {
SourceOperator::Search {
table_reference,
search,
..
@@ -487,7 +310,7 @@ impl Display for Operator {
}
Ok(())
}
Operator::Nothing => Ok(()),
SourceOperator::Nothing => Ok(()),
}
}
writeln!(f, "QUERY PLAN")?;
@@ -505,35 +328,15 @@ impl Display for Operator {
*/
pub fn get_table_ref_bitmask_for_operator<'a>(
tables: &'a Vec<BTreeTableReference>,
operator: &'a Operator,
operator: &'a SourceOperator,
) -> Result<usize> {
let mut table_refs_mask = 0;
match operator {
Operator::Aggregate { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
}
Operator::Filter {
source, predicates, ..
} => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
for predicate in predicates {
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, predicate)?;
}
}
Operator::Limit { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
}
Operator::Join { left, right, .. } => {
SourceOperator::Join { left, right, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, left)?;
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, right)?;
}
Operator::Order { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
}
Operator::Projection { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
}
Operator::Scan {
SourceOperator::Scan {
table_reference, ..
} => {
table_refs_mask |= 1
@@ -542,7 +345,7 @@ pub fn get_table_ref_bitmask_for_operator<'a>(
.position(|t| Rc::ptr_eq(&t.table, &table_reference.table))
.unwrap();
}
Operator::Search {
SourceOperator::Search {
table_reference, ..
} => {
table_refs_mask |= 1
@@ -551,7 +354,7 @@ pub fn get_table_ref_bitmask_for_operator<'a>(
.position(|t| Rc::ptr_eq(&t.table, &table_reference.table))
.unwrap();
}
Operator::Nothing => {}
SourceOperator::Nothing => {}
}
Ok(table_refs_mask)
}

View File

@@ -1,4 +1,6 @@
use super::plan::{Aggregate, BTreeTableReference, Direction, Operator, Plan, ProjectionColumn};
use super::plan::{
Aggregate, BTreeTableReference, Direction, Plan, ResultSetColumn, SourceOperator,
};
use crate::{function::Func, schema::Schema, util::normalize_ident, Result};
use sqlite3_parser::ast::{self, FromClause, JoinType, ResultColumn};
@@ -66,6 +68,7 @@ fn bind_column_references(
referenced_tables: &[BTreeTableReference],
) -> Result<()> {
match expr {
ast::Expr::AggRef { .. } => unreachable!(),
ast::Expr::Id(id) => {
let mut match_result = None;
for (tbl_idx, table) in referenced_tables.iter().enumerate() {
@@ -237,146 +240,157 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
let mut operator_id_counter = OperatorIdCounter::new();
// Parse the FROM clause
let (mut operator, referenced_tables) =
let (mut source, referenced_tables) =
parse_from(schema, from, &mut operator_id_counter)?;
let mut plan = Plan {
source,
result_columns: vec![],
where_clause: None,
group_by: None,
order_by: None,
aggregates: None,
limit: None,
referenced_tables,
available_indexes: schema.indexes.clone().into_values().flatten().collect(),
};
// Parse the WHERE clause
if let Some(w) = where_clause {
let mut predicates = vec![];
break_predicate_at_and_boundaries(w, &mut predicates);
for expr in predicates.iter_mut() {
bind_column_references(expr, &referenced_tables)?;
bind_column_references(expr, &plan.referenced_tables)?;
}
operator = Operator::Filter {
source: Box::new(operator),
predicates,
id: operator_id_counter.get_next_id(),
};
plan.where_clause = Some(predicates);
}
// If there are aggregate functions, we aggregate + project the columns.
// If there are no aggregate functions, we can simply project the columns.
// For a simple SELECT *, the projection operator is skipped as well.
let is_select_star = col_count == 1 && matches!(columns[0], ast::ResultColumn::Star);
if !is_select_star {
let mut aggregate_expressions = Vec::new();
let mut projection_expressions = Vec::with_capacity(col_count);
for column in columns.clone() {
match column {
ast::ResultColumn::Star => {
projection_expressions.push(ProjectionColumn::Star);
}
ast::ResultColumn::TableStar(name) => {
let name_normalized = normalize_ident(name.0.as_str());
let referenced_table = referenced_tables
.iter()
.find(|t| t.table_identifier == name_normalized);
if referenced_table.is_none() {
crate::bail_parse_error!("Table {} not found", name.0);
let mut aggregate_expressions = Vec::new();
for column in columns.clone() {
match column {
ast::ResultColumn::Star => {
for table_reference in plan.referenced_tables.iter() {
for (idx, col) in table_reference.table.columns.iter().enumerate() {
plan.result_columns.push(ResultSetColumn::Scalar(
ast::Expr::Column {
database: None, // TODO: support different databases
table: table_reference.table_index,
column: idx,
is_primary_key: col.primary_key,
},
));
}
let table_reference = referenced_table.unwrap();
projection_expressions
.push(ProjectionColumn::TableStar(table_reference.clone()));
}
ast::ResultColumn::Expr(mut expr, _) => {
bind_column_references(&mut expr, &referenced_tables)?;
projection_expressions.push(ProjectionColumn::Column(expr.clone()));
match expr.clone() {
ast::Expr::FunctionCall {
name,
distinctness: _,
args,
filter_over: _,
order_by: _,
} => {
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
match Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
args_count,
) {
Ok(Func::Agg(f)) => {
aggregate_expressions.push(Aggregate {
func: f,
args: args.unwrap(),
original_expr: expr.clone(),
});
}
Ok(_) => {
resolve_aggregates(&expr, &mut aggregate_expressions);
}
_ => {}
}
}
ast::Expr::FunctionCallStar {
name,
filter_over: _,
} => {
if let Ok(Func::Agg(f)) = Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
0,
) {
aggregate_expressions.push(Aggregate {
}
ast::ResultColumn::TableStar(name) => {
let name_normalized = normalize_ident(name.0.as_str());
let referenced_table = plan
.referenced_tables
.iter()
.find(|t| t.table_identifier == name_normalized);
if referenced_table.is_none() {
crate::bail_parse_error!("Table {} not found", name.0);
}
let table_reference = referenced_table.unwrap();
for (idx, col) in table_reference.table.columns.iter().enumerate() {
plan.result_columns
.push(ResultSetColumn::Scalar(ast::Expr::Column {
database: None, // TODO: support different databases
table: table_reference.table_index,
column: idx,
is_primary_key: col.primary_key,
}));
}
}
ast::ResultColumn::Expr(mut expr, _) => {
bind_column_references(&mut expr, &plan.referenced_tables)?;
match &expr {
ast::Expr::FunctionCall {
name,
distinctness: _,
args,
filter_over: _,
order_by: _,
} => {
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
match Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
args_count,
) {
Ok(Func::Agg(f)) => {
let agg = Aggregate {
func: f,
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
args: args.as_ref().unwrap().clone(),
original_expr: expr.clone(),
});
};
aggregate_expressions.push(agg.clone());
plan.result_columns.push(ResultSetColumn::Agg(agg));
}
Ok(_) => {
resolve_aggregates(&expr, &mut aggregate_expressions);
}
_ => {}
}
ast::Expr::Binary(lhs, _, rhs) => {
resolve_aggregates(&lhs, &mut aggregate_expressions);
resolve_aggregates(&rhs, &mut aggregate_expressions);
}
ast::Expr::FunctionCallStar {
name,
filter_over: _,
} => {
if let Ok(Func::Agg(f)) = Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
0,
) {
let agg = Aggregate {
func: f,
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
original_expr: expr.clone(),
};
aggregate_expressions.push(agg.clone());
plan.result_columns.push(ResultSetColumn::Agg(agg));
} else {
crate::bail_parse_error!(
"Invalid aggregate function: {}",
name.0
);
}
_ => {}
}
ast::Expr::Binary(lhs, _, rhs) => {
resolve_aggregates(&lhs, &mut aggregate_expressions);
resolve_aggregates(&rhs, &mut aggregate_expressions);
plan.result_columns
.push(ResultSetColumn::Scalar(expr.clone()));
}
e => {
plan.result_columns.push(ResultSetColumn::Scalar(e.clone()));
}
}
}
}
if let Some(group_by) = group_by.as_mut() {
for expr in group_by.exprs.iter_mut() {
bind_column_references(expr, &referenced_tables)?;
}
if aggregate_expressions.is_empty() {
crate::bail_parse_error!(
"GROUP BY clause without aggregate functions is not allowed"
);
}
for scalar in projection_expressions.iter() {
match scalar {
ProjectionColumn::Column(_) => {}
_ => {
crate::bail_parse_error!(
"Only column references are allowed in the SELECT clause when using GROUP BY"
);
}
}
}
}
if !aggregate_expressions.is_empty() {
operator = Operator::Aggregate {
source: Box::new(operator),
aggregates: aggregate_expressions,
group_by: group_by.map(|g| g.exprs), // TODO: support HAVING
id: operator_id_counter.get_next_id(),
step: 0,
}
}
if !projection_expressions.is_empty() {
operator = Operator::Projection {
source: Box::new(operator),
expressions: projection_expressions,
id: operator_id_counter.get_next_id(),
step: 0,
};
}
}
if let Some(group_by) = group_by.as_mut() {
for expr in group_by.exprs.iter_mut() {
bind_column_references(expr, &plan.referenced_tables)?;
}
if aggregate_expressions.is_empty() {
crate::bail_parse_error!(
"GROUP BY clause without aggregate functions is not allowed"
);
}
}
plan.group_by = group_by.map(|g| g.exprs);
plan.aggregates = if aggregate_expressions.is_empty() {
None
} else {
Some(aggregate_expressions)
};
// Parse the ORDER BY clause
if let Some(order_by) = select.order_by {
@@ -402,7 +416,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
o.expr
};
bind_column_references(&mut expr, &referenced_tables)?;
bind_column_references(&mut expr, &plan.referenced_tables)?;
key.push((
expr,
@@ -412,40 +426,22 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
}),
));
}
operator = Operator::Order {
source: Box::new(operator),
key,
id: operator_id_counter.get_next_id(),
step: 0,
};
plan.order_by = Some(key);
}
// Parse the LIMIT clause
if let Some(limit) = &select.limit {
operator = match &limit.expr {
plan.limit = match &limit.expr {
ast::Expr::Literal(ast::Literal::Numeric(n)) => {
let l = n.parse()?;
if l == 0 {
Operator::Nothing
} else {
Operator::Limit {
source: Box::new(operator),
limit: l,
id: operator_id_counter.get_next_id(),
step: 0,
}
}
Some(l)
}
_ => todo!(),
}
}
// Return the unoptimized query plan
Ok(Plan {
root_operator: operator,
referenced_tables,
available_indexes: schema.indexes.clone().into_values().flatten().collect(),
})
Ok(plan)
}
_ => todo!(),
}
@@ -456,9 +452,9 @@ fn parse_from(
schema: &Schema,
from: Option<FromClause>,
operator_id_counter: &mut OperatorIdCounter,
) -> Result<(Operator, Vec<BTreeTableReference>)> {
) -> Result<(SourceOperator, Vec<BTreeTableReference>)> {
if from.as_ref().and_then(|f| f.select.as_ref()).is_none() {
return Ok((Operator::Nothing, vec![]));
return Ok((SourceOperator::Nothing, vec![]));
}
let from = from.unwrap();
@@ -484,11 +480,10 @@ fn parse_from(
_ => todo!(),
};
let mut operator = Operator::Scan {
let mut operator = SourceOperator::Scan {
table_reference: first_table.clone(),
predicates: None,
id: operator_id_counter.get_next_id(),
step: 0,
iter_dir: None,
};
@@ -498,13 +493,12 @@ fn parse_from(
for join in from.joins.unwrap_or_default().into_iter() {
let (right, outer, predicates) =
parse_join(schema, join, operator_id_counter, &mut tables, table_index)?;
operator = Operator::Join {
operator = SourceOperator::Join {
left: Box::new(operator),
right: Box::new(right),
predicates,
outer,
id: operator_id_counter.get_next_id(),
step: 0,
};
table_index += 1;
}
@@ -518,7 +512,7 @@ fn parse_join(
operator_id_counter: &mut OperatorIdCounter,
tables: &mut Vec<BTreeTableReference>,
table_index: usize,
) -> Result<(Operator, bool, Option<Vec<ast::Expr>>)> {
) -> Result<(SourceOperator, bool, Option<Vec<ast::Expr>>)> {
let ast::JoinedSelectTable {
operator,
table,
@@ -574,11 +568,10 @@ fn parse_join(
}
Ok((
Operator::Scan {
SourceOperator::Scan {
table_reference: table.clone(),
predicates: None,
id: operator_id_counter.get_next_id(),
step: 0,
iter_dir: None,
},
outer,

View File

@@ -17,12 +17,7 @@ pub fn translate_select(
connection: Weak<Connection>,
) -> Result<Program> {
let select_plan = prepare_select_plan(schema, select)?;
let (optimized_plan, expr_result_cache) = optimize_plan(select_plan)?;
println!("{:?}", expr_result_cache);
emit_program(
database_header,
optimized_plan,
expr_result_cache,
connection,
)
let optimized_plan = optimize_plan(select_plan)?;
// println!("optimized_plan: {:?}", optimized_plan);
emit_program(database_header, optimized_plan, connection)
}

View File

@@ -638,6 +638,7 @@ impl ToTokens for Expr {
}
Self::Id(id) => id.to_tokens(s),
Self::Column { .. } => Ok(()),
Self::AggRef { .. } => Ok(()),
Self::InList { lhs, not, rhs } => {
lhs.to_tokens(s)?;
if *not {

View File

@@ -338,6 +338,11 @@ pub enum Expr {
/// is the column a primary key
is_primary_key: bool,
},
/// AggRef is a reference to a computed aggregate
AggRef {
/// index of the aggregate in the aggregates vector parsed from the query
index: usize,
},
/// `IN`
InList {
/// expression