Merge 'core: (another) refactor of read path query processing logic' from Jussi Saurio

# (another) refactor of read path query processing logic
This PR rewrites our select query processing architecture by moving away
from the stateful operator-based execution model, back to a more direct
bytecode generation approach that, IMO, is easier to follow. A large
part of the bytecode emission itself (`program.emit_insn(...)`) is
copy-pasted from the old implementation (after all, it did _work_) --
it's just structured differently.
## Main Changes
1. Removed the `step()` state machine from operators. Previously, each
operator had internal state tracking its execution progress, and parent
operators would call `.step()` on their children until they needed to do
something else. Following the execution by reading the code was
difficult, and the abstraction was also too general: there was a lot
of unnecessary pattern matching and special casing to make query
execution fit the model, when honestly the evaluation of a SELECT
without any CTEs, subqueries etc. can only go a few different ways.
2. Because of the above change, the main codegen function
`emit_program()` now contains a series of linear conditional steps
instead of kicking off the state machines with `root_operator.step()`.
These steps are just things like: "open the cursors", "open the loops",
"emit a record into either the main output or a sorter", etc.
3. The `Plan` struct now (again) contains most of the familiar SELECT
query components (WHERE clause, GROUP BY, ORDER BY, etc.) rather than
having all of them embedded in a tree of operators. The operator tree
now ONLY consists of operators that read from a source table in some way
-- so it could just be called a join tree, I guess.
4. There's now `plan.result_columns`, which is _ALWAYS_ evaluated to get
the final results of a SELECT. Previously, the operator state machine
approach had a hodgepodge of different ways of arriving at the result row.
5. Removed operators:
   - Removed Filter operator (even in the previous version, the Filter
operator -- which is really just the WHERE clause -- had its predicates
pushed down to the table loops, so it never really existed in the
bytecode emission phase anymore)
   - Removed Projection operator (`plan.result_columns`)
   - Removed Limit operator (`plan.limit`)
   - Removed Aggregate operator (`plan.group_by` and `plan.aggregates`)
   - Removed Order operator (`plan.order_by`)
6. Added `ast::Expr::Column` to the vendored sqlite3 parser -- column
resolution is now done as early as possible. This eliminates repeated
string comparisons during execution, i.e. there's no need for
`resolve_ident_table()` etc. anymore.
7. Simplified expression result caching by removing the complex, and
frankly weird, ExpressionResultCache apparatus. The refactored code
handles this by tracking which cursor to read columns from at a given
time, and copies values from existing registers if the expression is a
computation that has already been done in a previous step of the
execution. For example in:
```
limbo> select concat(u.first_name, '-LOL'), sum(u.age) from users u group by concat(u.first_name, '-LOL') order by sum(u.age) desc limit 10;
Michael-LOL|11204
David-LOL|8758
Robert-LOL|8109
Jennifer-LOL|7700
John-LOL|7299
Christopher-LOL|6397
James-LOL|5921
Joseph-LOL|5711
Brian-LOL|5059
William-LOL|5047
```
the query execution engine knows that `concat(u.first_name, '-LOL')` is
the second column of the ORDER BY sorter without any complex caching.
**HACK:** For deduplicating expressions in ORDER BY and the SELECT body,
the code still relies on expression `==` equality to make those
decisions, which sucks (e.g. `sum(x) != SUM(x)`). I've marked the parts
where this is used with a TODO; we should have a custom expression
equality comparison function instead. This is not a correctness-breaking
thing, but still.
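The early column resolution from point 6 boils down to rewriting name references into positional references once, at plan time. Here's a rough sketch under heavily simplified, hypothetical types (the real code also handles qualified names, rowid aliases, and carries a `BTreeTableReference` with much more state):

```rust
// Simplified stand-in for the parser AST; the real ast::Expr has many more variants.
#[derive(Debug, PartialEq)]
enum Expr {
    Id(String),
    Column { table: usize, column: usize },
}

// A table reference is just (identifier, column names) here.
// Rewrites Expr::Id("age") into Expr::Column { table, column } so that
// execution never has to compare column-name strings again.
fn bind_column_reference(expr: &mut Expr, tables: &[(&str, Vec<&str>)]) -> Result<(), String> {
    let name = match expr {
        Expr::Id(name) => name.clone(),
        // Already positional (or a literal, etc.): nothing to do.
        _ => return Ok(()),
    };
    let mut found = None;
    for (tbl_idx, (_, cols)) in tables.iter().enumerate() {
        if let Some(col_idx) = cols.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
            if found.is_some() {
                return Err(format!("Column {} is ambiguous", name));
            }
            found = Some(Expr::Column { table: tbl_idx, column: col_idx });
        }
    }
    match found {
        Some(col) => {
            *expr = col;
            Ok(())
        }
        None => Err(format!("Column {} not found", name)),
    }
}
```

After binding, everything downstream (the optimizer, codegen, the table-ref bitmask computation) can treat a column as a plain `(table, column)` index pair.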
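A minimal sketch of what the custom equality comparison from the HACK note could look like -- the `Expr` type here is a simplified stand-in for the vendored parser's AST, and the function name is illustrative:

```rust
// Simplified stand-in for the parser AST.
#[derive(Debug)]
enum Expr {
    FunctionCall { name: String, args: Vec<Expr> },
    Column { table: usize, column: usize },
}

// Structural equality that compares function names case-insensitively,
// so `sum(x)` and `SUM(x)` deduplicate to the same expression.
fn exprs_equivalent(a: &Expr, b: &Expr) -> bool {
    match (a, b) {
        (
            Expr::FunctionCall { name: n1, args: a1 },
            Expr::FunctionCall { name: n2, args: a2 },
        ) => {
            n1.eq_ignore_ascii_case(n2)
                && a1.len() == a2.len()
                && a1.iter().zip(a2.iter()).all(|(x, y)| exprs_equivalent(x, y))
        }
        (
            Expr::Column { table: t1, column: c1 },
            Expr::Column { table: t2, column: c2 },
        ) => t1 == t2 && c1 == c2,
        _ => false,
    }
}
```

Since binding already turns identifiers into positional `Column` references, the only case-sensitivity problem left is function names, which is what this handles.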
## In short
- No more state machines
- The operator tree is now only a "join tree", pretty much
- No weird general purpose `ExpressionResultCache`
- More direct mapping between SQL operations and generated bytecode --
there's really no harm in carrying the "group by" etc concepts in the
bytecode generation phase instead of burying them inside Operators
- When a ResultRow is emitted, it is _always_ done by evaluating
`plan.result_columns`, instead of the special-casing and hacks that
existed previously
- 600+ LOC removed
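
The "no more state machines" point can be sketched like this -- a hypothetical, heavily simplified `emit_program()` whose body is just a flat sequence of conditional steps (step names and the `Plan` shape here are illustrative, not the actual Limbo API):

```rust
// Hypothetical, simplified Plan: just enough to drive the step sequence.
struct Plan {
    has_group_by: bool,
    has_order_by: bool,
    has_limit: bool,
}

// Instead of kicking off operator state machines with root_operator.step(),
// codegen is a linear series of conditional steps. Each step is represented
// here by its name instead of the bytecode it would emit.
fn emit_program(plan: &Plan) -> Vec<&'static str> {
    let mut steps = vec!["open_cursors", "open_loops", "evaluate_where_predicates"];
    if plan.has_group_by {
        steps.push("accumulate_group_by_aggregates");
    }
    if plan.has_order_by {
        // Rows go into a sorter first; results are emitted from it later.
        steps.push("emit_row_into_sorter");
    } else {
        // ResultRow is always produced by evaluating plan.result_columns.
        steps.push("emit_result_row_from_result_columns");
    }
    steps.push("close_loops");
    if plan.has_order_by {
        steps.push("loop_over_sorter_and_emit_result_rows");
    }
    if plan.has_limit {
        steps.push("halt_when_limit_reached");
    }
    steps
}
```

Reading the generated step list top to bottom corresponds to reading the emitted bytecode top to bottom, which is the whole point.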

Closes #416
This commit is contained in:
jussisaurio
2024-11-30 10:03:58 +02:00
14 changed files with 2521 additions and 3194 deletions


@@ -235,8 +235,8 @@ impl Connection {
Cmd::ExplainQueryPlan(stmt) => {
match stmt {
ast::Stmt::Select(select) => {
let plan = prepare_select_plan(&self.schema.borrow(), select)?;
let (plan, _) = optimize_plan(plan)?;
let plan = prepare_select_plan(&*self.schema.borrow(), select)?;
let plan = optimize_plan(plan)?;
println!("{}", plan);
}
_ => todo!(),


@@ -90,7 +90,7 @@ impl Table {
None => None,
},
Table::Pseudo(table) => match table.columns.get(index) {
Some(column) => Some(&column.name),
Some(_) => None,
None => None,
},
}

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -98,7 +98,6 @@ pub fn translate_insert(
expr,
column_registers_start + col,
None,
None,
)?;
}
program.emit_insn(Insn::Yield {

File diff suppressed because it is too large


@@ -9,20 +9,41 @@ use sqlite3_parser::ast;
use crate::{
function::AggFunc,
schema::{BTreeTable, Index},
util::normalize_ident,
Result,
};
#[derive(Debug)]
pub struct ResultSetColumn {
pub expr: ast::Expr,
// TODO: encode which aggregates (e.g. index bitmask of plan.aggregates) are present in this column
pub contains_aggregates: bool,
}
#[derive(Debug)]
pub struct Plan {
pub root_operator: Operator,
/// A tree of sources (tables).
pub source: SourceOperator,
/// the columns inside SELECT ... FROM
pub result_columns: Vec<ResultSetColumn>,
/// where clause split into a vec at 'AND' boundaries.
pub where_clause: Option<Vec<ast::Expr>>,
/// group by clause
pub group_by: Option<Vec<ast::Expr>>,
/// order by clause
pub order_by: Option<Vec<(ast::Expr, Direction)>>,
/// all the aggregates collected from the result columns, order by, and (TODO) having clauses
pub aggregates: Option<Vec<Aggregate>>,
/// limit clause
pub limit: Option<usize>,
/// all the tables referenced in the query
pub referenced_tables: Vec<BTreeTableReference>,
/// all the indexes available
pub available_indexes: Vec<Rc<Index>>,
}
impl Display for Plan {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.root_operator)
write!(f, "{}", self.source)
}
}
@@ -33,82 +54,20 @@ pub enum IterationDirection {
}
/**
An Operator is a Node in the query plan.
Operators form a tree structure, with each having zero or more children.
For example, a query like `SELECT t1.foo FROM t1 ORDER BY t1.foo LIMIT 1` would have the following structure:
Limit
Order
Project
Scan
Operators also have a unique ID, which is used to identify them in the query plan and attach metadata.
They also have a step counter, which is used to track the current step in the operator's execution.
TODO: perhaps 'step' shouldn't be in this struct, since it's an execution time concept, not a plan time concept.
A SourceOperator is a Node in the query plan that reads data from a table.
*/
#[derive(Clone, Debug)]
pub enum Operator {
// Aggregate operator
// This operator is used to compute aggregate functions like SUM, AVG, COUNT, etc.
// It takes a source operator and a list of aggregate functions to compute.
// GROUP BY is not supported yet.
Aggregate {
id: usize,
source: Box<Operator>,
aggregates: Vec<Aggregate>,
group_by: Option<Vec<ast::Expr>>,
step: usize,
},
// Filter operator
// This operator is used to filter rows from the source operator.
// It takes a source operator and a list of predicates to evaluate.
// Only rows for which all predicates evaluate to true are passed to the next operator.
// Generally filter operators will only exist in unoptimized plans,
// as the optimizer will try to push filters down to the lowest possible level,
// e.g. a table scan.
Filter {
id: usize,
source: Box<Operator>,
predicates: Vec<ast::Expr>,
},
// Limit operator
// This operator is used to limit the number of rows returned by the source operator.
Limit {
id: usize,
source: Box<Operator>,
limit: usize,
step: usize,
},
pub enum SourceOperator {
// Join operator
// This operator is used to join two source operators.
// It takes a left and right source operator, a list of predicates to evaluate,
// and a boolean indicating whether it is an outer join.
Join {
id: usize,
left: Box<Operator>,
right: Box<Operator>,
left: Box<SourceOperator>,
right: Box<SourceOperator>,
predicates: Option<Vec<ast::Expr>>,
outer: bool,
step: usize,
},
// Order operator
// This operator is used to sort the rows returned by the source operator.
Order {
id: usize,
source: Box<Operator>,
key: Vec<(ast::Expr, Direction)>,
step: usize,
},
// Projection operator
// This operator is used to project columns from the source operator.
// It takes a source operator and a list of expressions to evaluate.
// e.g. SELECT foo, bar FROM t1
// In this example, the expressions would be [foo, bar]
// and the source operator would be a Scan operator for table t1.
Projection {
id: usize,
source: Box<Operator>,
expressions: Vec<ProjectionColumn>,
step: usize,
},
// Scan operator
// This operator is used to scan a table.
@@ -123,7 +82,6 @@ pub enum Operator {
id: usize,
table_reference: BTreeTableReference,
predicates: Option<Vec<ast::Expr>>,
step: usize,
iter_dir: Option<IterationDirection>,
},
// Search operator
@@ -134,7 +92,6 @@ pub enum Operator {
table_reference: BTreeTableReference,
search: Search,
predicates: Option<Vec<ast::Expr>>,
step: usize,
},
// Nothing operator
// This operator is used to represent an empty query.
@@ -146,6 +103,7 @@ pub enum Operator {
pub struct BTreeTableReference {
pub table: Rc<BTreeTable>,
pub table_identifier: String,
pub table_index: usize,
}
/// An enum that represents a search operation that can be used to search for a row in a table using an index
@@ -168,136 +126,13 @@ pub enum Search {
},
}
#[derive(Clone, Debug)]
pub enum ProjectionColumn {
Column(ast::Expr),
Star,
TableStar(BTreeTableReference),
}
impl ProjectionColumn {
pub fn column_count(&self, referenced_tables: &[BTreeTableReference]) -> usize {
match self {
ProjectionColumn::Column(_) => 1,
ProjectionColumn::Star => {
let mut count = 0;
for table_reference in referenced_tables {
count += table_reference.table.columns.len();
}
count
}
ProjectionColumn::TableStar(table_reference) => table_reference.table.columns.len(),
}
}
}
impl Operator {
pub fn column_count(&self, referenced_tables: &[BTreeTableReference]) -> usize {
match self {
Operator::Aggregate {
group_by,
aggregates,
..
} => aggregates.len() + group_by.as_ref().map_or(0, |g| g.len()),
Operator::Filter { source, .. } => source.column_count(referenced_tables),
Operator::Limit { source, .. } => source.column_count(referenced_tables),
Operator::Join { left, right, .. } => {
left.column_count(referenced_tables) + right.column_count(referenced_tables)
}
Operator::Order { source, .. } => source.column_count(referenced_tables),
Operator::Projection { expressions, .. } => expressions
.iter()
.map(|e| e.column_count(referenced_tables))
.sum(),
Operator::Scan {
table_reference, ..
} => table_reference.table.columns.len(),
Operator::Search {
table_reference, ..
} => table_reference.table.columns.len(),
Operator::Nothing => 0,
}
}
pub fn column_names(&self) -> Vec<String> {
match self {
Operator::Aggregate {
aggregates,
group_by,
..
} => {
let mut names = vec![];
for agg in aggregates.iter() {
names.push(agg.func.to_string().to_string());
}
if let Some(group_by) = group_by {
for expr in group_by.iter() {
match expr {
ast::Expr::Id(ident) => names.push(ident.0.clone()),
ast::Expr::Qualified(tbl, ident) => {
names.push(format!("{}.{}", tbl.0, ident.0))
}
e => names.push(e.to_string()),
}
}
}
names
}
Operator::Filter { source, .. } => source.column_names(),
Operator::Limit { source, .. } => source.column_names(),
Operator::Join { left, right, .. } => {
let mut names = left.column_names();
names.extend(right.column_names());
names
}
Operator::Order { source, .. } => source.column_names(),
Operator::Projection { expressions, .. } => expressions
.iter()
.map(|e| match e {
ProjectionColumn::Column(expr) => match expr {
ast::Expr::Id(ident) => ident.0.clone(),
ast::Expr::Qualified(tbl, ident) => format!("{}.{}", tbl.0, ident.0),
_ => "expr".to_string(),
},
ProjectionColumn::Star => "*".to_string(),
ProjectionColumn::TableStar(table_reference) => {
format!("{}.{}", table_reference.table_identifier, "*")
}
})
.collect(),
Operator::Scan {
table_reference, ..
} => table_reference
.table
.columns
.iter()
.map(|c| c.name.clone())
.collect(),
Operator::Search {
table_reference, ..
} => table_reference
.table
.columns
.iter()
.map(|c| c.name.clone())
.collect(),
Operator::Nothing => vec![],
}
}
impl SourceOperator {
pub fn id(&self) -> usize {
match self {
Operator::Aggregate { id, .. } => *id,
Operator::Filter { id, .. } => *id,
Operator::Limit { id, .. } => *id,
Operator::Join { id, .. } => *id,
Operator::Order { id, .. } => *id,
Operator::Projection { id, .. } => *id,
Operator::Scan { id, .. } => *id,
Operator::Search { id, .. } => *id,
Operator::Nothing => unreachable!(),
SourceOperator::Join { id, .. } => *id,
SourceOperator::Scan { id, .. } => *id,
SourceOperator::Search { id, .. } => *id,
SourceOperator::Nothing => unreachable!(),
}
}
}
@@ -337,10 +172,10 @@ impl Display for Aggregate {
}
// For EXPLAIN QUERY PLAN
impl Display for Operator {
impl Display for SourceOperator {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
fn fmt_operator(
operator: &Operator,
operator: &SourceOperator,
f: &mut Formatter,
level: usize,
last: bool,
@@ -356,34 +191,7 @@ impl Display for Operator {
};
match operator {
Operator::Aggregate {
source, aggregates, ..
} => {
// e.g. Aggregate count(*), sum(x)
let aggregates_display_string = aggregates
.iter()
.map(|agg| agg.to_string())
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}AGGREGATE {}", indent, aggregates_display_string)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Filter {
source, predicates, ..
} => {
let predicates_string = predicates
.iter()
.map(|p| p.to_string())
.collect::<Vec<String>>()
.join(" AND ");
writeln!(f, "{}FILTER {}", indent, predicates_string)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Limit { source, limit, .. } => {
writeln!(f, "{}TAKE {}", indent, limit)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Join {
SourceOperator::Join {
left,
right,
predicates,
@@ -408,35 +216,7 @@ impl Display for Operator {
fmt_operator(left, f, level + 1, false)?;
fmt_operator(right, f, level + 1, true)
}
Operator::Order { source, key, .. } => {
let sort_keys_string = key
.iter()
.map(|(expr, dir)| format!("{} {}", expr, dir))
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}SORT {}", indent, sort_keys_string)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Projection {
source,
expressions,
..
} => {
let expressions = expressions
.iter()
.map(|expr| match expr {
ProjectionColumn::Column(c) => c.to_string(),
ProjectionColumn::Star => "*".to_string(),
ProjectionColumn::TableStar(table_reference) => {
format!("{}.{}", table_reference.table_identifier, "*")
}
})
.collect::<Vec<String>>()
.join(", ");
writeln!(f, "{}PROJECT {}", indent, expressions)?;
fmt_operator(source, f, level + 1, true)
}
Operator::Scan {
SourceOperator::Scan {
table_reference,
predicates: filter,
..
@@ -464,7 +244,7 @@ impl Display for Operator {
}?;
Ok(())
}
Operator::Search {
SourceOperator::Search {
table_reference,
search,
..
@@ -487,7 +267,7 @@ impl Display for Operator {
}
Ok(())
}
Operator::Nothing => Ok(()),
SourceOperator::Nothing => Ok(()),
}
}
writeln!(f, "QUERY PLAN")?;
@@ -505,35 +285,15 @@ impl Display for Operator {
*/
pub fn get_table_ref_bitmask_for_operator<'a>(
tables: &'a Vec<BTreeTableReference>,
operator: &'a Operator,
operator: &'a SourceOperator,
) -> Result<usize> {
let mut table_refs_mask = 0;
match operator {
Operator::Aggregate { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
}
Operator::Filter {
source, predicates, ..
} => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
for predicate in predicates {
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, predicate)?;
}
}
Operator::Limit { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
}
Operator::Join { left, right, .. } => {
SourceOperator::Join { left, right, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, left)?;
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, right)?;
}
Operator::Order { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
}
Operator::Projection { source, .. } => {
table_refs_mask |= get_table_ref_bitmask_for_operator(tables, source)?;
}
Operator::Scan {
SourceOperator::Scan {
table_reference, ..
} => {
table_refs_mask |= 1
@@ -542,7 +302,7 @@ pub fn get_table_ref_bitmask_for_operator<'a>(
.position(|t| Rc::ptr_eq(&t.table, &table_reference.table))
.unwrap();
}
Operator::Search {
SourceOperator::Search {
table_reference, ..
} => {
table_refs_mask |= 1
@@ -551,7 +311,7 @@ pub fn get_table_ref_bitmask_for_operator<'a>(
.position(|t| Rc::ptr_eq(&t.table, &table_reference.table))
.unwrap();
}
Operator::Nothing => {}
SourceOperator::Nothing => {}
}
Ok(table_refs_mask)
}
@@ -574,46 +334,12 @@ pub fn get_table_ref_bitmask_for_ast_expr<'a>(
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, e1)?;
table_refs_mask |= get_table_ref_bitmask_for_ast_expr(tables, e2)?;
}
ast::Expr::Id(ident) => {
let ident = normalize_ident(&ident.0);
let matching_tables = tables
.iter()
.enumerate()
.filter(|(_, table_reference)| table_reference.table.get_column(&ident).is_some());
let mut matches = 0;
let mut matching_tbl = None;
for table in matching_tables {
matching_tbl = Some(table);
matches += 1;
if matches > 1 {
crate::bail_parse_error!("ambiguous column name {}", &ident)
}
}
if let Some((tbl_index, _)) = matching_tbl {
table_refs_mask |= 1 << tbl_index;
} else {
crate::bail_parse_error!("column not found: {}", &ident)
}
ast::Expr::Column { table, .. } => {
table_refs_mask |= 1 << table;
}
ast::Expr::Qualified(tbl, ident) => {
let tbl = normalize_ident(&tbl.0);
let ident = normalize_ident(&ident.0);
let matching_table = tables
.iter()
.enumerate()
.find(|(_, t)| t.table_identifier == tbl);
if matching_table.is_none() {
crate::bail_parse_error!("introspect: table not found: {}", &tbl)
}
let (table_index, table_reference) = matching_table.unwrap();
if table_reference.table.get_column(&ident).is_none() {
crate::bail_parse_error!("column with qualified name {}.{} not found", &tbl, &ident)
}
table_refs_mask |= 1 << table_index;
ast::Expr::Id(_) => unreachable!("Id should be resolved to a Column before optimizer"),
ast::Expr::Qualified(_, _) => {
unreachable!("Qualified should be resolved to a Column before optimizer")
}
ast::Expr::Literal(_) => {}
ast::Expr::Like { lhs, rhs, .. } => {


@@ -1,4 +1,6 @@
use super::plan::{Aggregate, BTreeTableReference, Direction, Operator, Plan, ProjectionColumn};
use super::plan::{
Aggregate, BTreeTableReference, Direction, Plan, ResultSetColumn, SourceOperator,
};
use crate::{function::Func, schema::Schema, util::normalize_ident, Result};
use sqlite3_parser::ast::{self, FromClause, JoinType, ResultColumn};
@@ -18,6 +20,9 @@ impl OperatorIdCounter {
}
fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec<Aggregate>) {
if aggs.iter().any(|a| a.original_expr == *expr) {
return;
}
match expr {
ast::Expr::FunctionCall { name, args, .. } => {
let args_count = if let Some(args) = &args {
@@ -55,10 +60,171 @@ fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec<Aggregate>) {
resolve_aggregates(lhs, aggs);
resolve_aggregates(rhs, aggs);
}
// TODO: handle other expressions that may contain aggregates
_ => {}
}
}
/// Recursively resolve column references in an expression.
/// Id, Qualified and DoublyQualified are converted to Column.
fn bind_column_references(
expr: &mut ast::Expr,
referenced_tables: &[BTreeTableReference],
) -> Result<()> {
match expr {
ast::Expr::Id(id) => {
let mut match_result = None;
for (tbl_idx, table) in referenced_tables.iter().enumerate() {
let col_idx = table
.table
.columns
.iter()
.position(|c| c.name.eq_ignore_ascii_case(&id.0));
if col_idx.is_some() {
if match_result.is_some() {
crate::bail_parse_error!("Column {} is ambiguous", id.0);
}
let col = table.table.columns.get(col_idx.unwrap()).unwrap();
match_result = Some((tbl_idx, col_idx.unwrap(), col.primary_key));
}
}
if match_result.is_none() {
crate::bail_parse_error!("Column {} not found", id.0);
}
let (tbl_idx, col_idx, is_primary_key) = match_result.unwrap();
*expr = ast::Expr::Column {
database: None, // TODO: support different databases
table: tbl_idx,
column: col_idx,
is_rowid_alias: is_primary_key,
};
Ok(())
}
ast::Expr::Qualified(tbl, id) => {
let matching_tbl_idx = referenced_tables
.iter()
.position(|t| t.table_identifier.eq_ignore_ascii_case(&tbl.0));
if matching_tbl_idx.is_none() {
crate::bail_parse_error!("Table {} not found", tbl.0);
}
let tbl_idx = matching_tbl_idx.unwrap();
let col_idx = referenced_tables[tbl_idx]
.table
.columns
.iter()
.position(|c| c.name.eq_ignore_ascii_case(&id.0));
if col_idx.is_none() {
crate::bail_parse_error!("Column {} not found", id.0);
}
let col = referenced_tables[tbl_idx]
.table
.columns
.get(col_idx.unwrap())
.unwrap();
*expr = ast::Expr::Column {
database: None, // TODO: support different databases
table: tbl_idx,
column: col_idx.unwrap(),
is_rowid_alias: col.primary_key,
};
Ok(())
}
ast::Expr::Between {
lhs,
not: _,
start,
end,
} => {
bind_column_references(lhs, referenced_tables)?;
bind_column_references(start, referenced_tables)?;
bind_column_references(end, referenced_tables)?;
Ok(())
}
ast::Expr::Binary(expr, _operator, expr1) => {
bind_column_references(expr, referenced_tables)?;
bind_column_references(expr1, referenced_tables)?;
Ok(())
}
ast::Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(base) = base {
bind_column_references(base, referenced_tables)?;
}
for (when, then) in when_then_pairs {
bind_column_references(when, referenced_tables)?;
bind_column_references(then, referenced_tables)?;
}
if let Some(else_expr) = else_expr {
bind_column_references(else_expr, referenced_tables)?;
}
Ok(())
}
ast::Expr::Cast { expr, type_name: _ } => bind_column_references(expr, referenced_tables),
ast::Expr::Collate(expr, _string) => bind_column_references(expr, referenced_tables),
ast::Expr::FunctionCall {
name: _,
distinctness: _,
args,
order_by: _,
filter_over: _,
} => {
if let Some(args) = args {
for arg in args {
bind_column_references(arg, referenced_tables)?;
}
}
Ok(())
}
// Column references cannot exist before binding
ast::Expr::Column { .. } => unreachable!(),
ast::Expr::DoublyQualified(_, _, _) => todo!(),
ast::Expr::Exists(_) => todo!(),
ast::Expr::FunctionCallStar { .. } => Ok(()),
ast::Expr::InList { lhs, not: _, rhs } => {
bind_column_references(lhs, referenced_tables)?;
if let Some(rhs) = rhs {
for arg in rhs {
bind_column_references(arg, referenced_tables)?;
}
}
Ok(())
}
ast::Expr::InSelect { .. } => todo!(),
ast::Expr::InTable { .. } => todo!(),
ast::Expr::IsNull(expr) => {
bind_column_references(expr, referenced_tables)?;
Ok(())
}
ast::Expr::Like { lhs, rhs, .. } => {
bind_column_references(lhs, referenced_tables)?;
bind_column_references(rhs, referenced_tables)?;
Ok(())
}
ast::Expr::Literal(_) => Ok(()),
ast::Expr::Name(_) => todo!(),
ast::Expr::NotNull(expr) => {
bind_column_references(expr, referenced_tables)?;
Ok(())
}
ast::Expr::Parenthesized(expr) => {
for e in expr.iter_mut() {
bind_column_references(e, referenced_tables)?;
}
Ok(())
}
ast::Expr::Raise(_, _) => todo!(),
ast::Expr::Subquery(_) => todo!(),
ast::Expr::Unary(_, expr) => {
bind_column_references(expr, referenced_tables)?;
Ok(())
}
ast::Expr::Variable(_) => todo!(),
}
}
#[allow(clippy::extra_unused_lifetimes)]
pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<Plan> {
match select.body.select {
@@ -66,7 +232,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
columns,
from,
where_clause,
group_by,
mut group_by,
..
} => {
let col_count = columns.len();
@@ -77,139 +243,173 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
let mut operator_id_counter = OperatorIdCounter::new();
// Parse the FROM clause
let (mut operator, referenced_tables) =
parse_from(schema, from, &mut operator_id_counter)?;
let (source, referenced_tables) = parse_from(schema, from, &mut operator_id_counter)?;
let mut plan = Plan {
source,
result_columns: vec![],
where_clause: None,
group_by: None,
order_by: None,
aggregates: None,
limit: None,
referenced_tables,
available_indexes: schema.indexes.clone().into_values().flatten().collect(),
};
// Parse the WHERE clause
if let Some(w) = where_clause {
let mut predicates = vec![];
break_predicate_at_and_boundaries(w, &mut predicates);
operator = Operator::Filter {
source: Box::new(operator),
predicates,
id: operator_id_counter.get_next_id(),
};
for expr in predicates.iter_mut() {
bind_column_references(expr, &plan.referenced_tables)?;
}
plan.where_clause = Some(predicates);
}
// If there are aggregate functions, we aggregate + project the columns.
// If there are no aggregate functions, we can simply project the columns.
// For a simple SELECT *, the projection operator is skipped as well.
let is_select_star = col_count == 1 && matches!(columns[0], ast::ResultColumn::Star);
if !is_select_star {
let mut aggregate_expressions = Vec::new();
let mut projection_expressions = Vec::with_capacity(col_count);
for column in columns.clone() {
match column {
ast::ResultColumn::Star => {
projection_expressions.push(ProjectionColumn::Star);
}
ast::ResultColumn::TableStar(name) => {
let name_normalized = normalize_ident(name.0.as_str());
let referenced_table = referenced_tables
.iter()
.find(|t| t.table_identifier == name_normalized);
if referenced_table.is_none() {
crate::bail_parse_error!("Table {} not found", name.0);
let mut aggregate_expressions = Vec::new();
for column in columns.clone() {
match column {
ast::ResultColumn::Star => {
for table_reference in plan.referenced_tables.iter() {
for (idx, col) in table_reference.table.columns.iter().enumerate() {
plan.result_columns.push(ResultSetColumn {
expr: ast::Expr::Column {
database: None, // TODO: support different databases
table: table_reference.table_index,
column: idx,
is_rowid_alias: col.primary_key,
},
contains_aggregates: false,
});
}
let table_reference = referenced_table.unwrap();
projection_expressions
.push(ProjectionColumn::TableStar(table_reference.clone()));
}
ast::ResultColumn::Expr(expr, _) => {
projection_expressions.push(ProjectionColumn::Column(expr.clone()));
match expr.clone() {
ast::Expr::FunctionCall {
name,
distinctness: _,
args,
filter_over: _,
order_by: _,
} => {
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
match Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
args_count,
) {
Ok(Func::Agg(f)) => {
aggregate_expressions.push(Aggregate {
func: f,
args: args.unwrap(),
original_expr: expr.clone(),
});
}
Ok(_) => {
resolve_aggregates(&expr, &mut aggregate_expressions);
}
_ => {}
}
}
ast::Expr::FunctionCallStar {
name,
filter_over: _,
} => {
if let Ok(Func::Agg(f)) = Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
0,
) {
aggregate_expressions.push(Aggregate {
}
ast::ResultColumn::TableStar(name) => {
let name_normalized = normalize_ident(name.0.as_str());
let referenced_table = plan
.referenced_tables
.iter()
.find(|t| t.table_identifier == name_normalized);
if referenced_table.is_none() {
crate::bail_parse_error!("Table {} not found", name.0);
}
let table_reference = referenced_table.unwrap();
for (idx, col) in table_reference.table.columns.iter().enumerate() {
plan.result_columns.push(ResultSetColumn {
expr: ast::Expr::Column {
database: None, // TODO: support different databases
table: table_reference.table_index,
column: idx,
is_rowid_alias: col.primary_key,
},
contains_aggregates: false,
});
}
}
ast::ResultColumn::Expr(mut expr, _) => {
bind_column_references(&mut expr, &plan.referenced_tables)?;
match &expr {
ast::Expr::FunctionCall {
name,
distinctness: _,
args,
filter_over: _,
order_by: _,
} => {
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
match Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
args_count,
) {
Ok(Func::Agg(f)) => {
let agg = Aggregate {
func: f,
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
args: args.as_ref().unwrap().clone(),
original_expr: expr.clone(),
};
aggregate_expressions.push(agg.clone());
plan.result_columns.push(ResultSetColumn {
expr: expr.clone(),
contains_aggregates: true,
});
}
Ok(_) => {
let cur_agg_count = aggregate_expressions.len();
resolve_aggregates(&expr, &mut aggregate_expressions);
let contains_aggregates =
cur_agg_count != aggregate_expressions.len();
plan.result_columns.push(ResultSetColumn {
expr: expr.clone(),
contains_aggregates,
});
}
_ => {}
}
ast::Expr::Binary(lhs, _, rhs) => {
resolve_aggregates(&lhs, &mut aggregate_expressions);
resolve_aggregates(&rhs, &mut aggregate_expressions);
}
ast::Expr::FunctionCallStar {
name,
filter_over: _,
} => {
if let Ok(Func::Agg(f)) = Func::resolve_function(
normalize_ident(name.0.as_str()).as_str(),
0,
) {
let agg = Aggregate {
func: f,
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
original_expr: expr.clone(),
};
aggregate_expressions.push(agg.clone());
plan.result_columns.push(ResultSetColumn {
expr: expr.clone(),
contains_aggregates: true,
});
} else {
crate::bail_parse_error!(
"Invalid aggregate function: {}",
name.0
);
}
_ => {}
}
expr => {
let cur_agg_count = aggregate_expressions.len();
resolve_aggregates(expr, &mut aggregate_expressions);
let contains_aggregates =
cur_agg_count != aggregate_expressions.len();
plan.result_columns.push(ResultSetColumn {
expr: expr.clone(),
contains_aggregates,
});
}
}
}
}
if let Some(_group_by) = group_by.as_ref() {
if aggregate_expressions.is_empty() {
crate::bail_parse_error!(
"GROUP BY clause without aggregate functions is not allowed"
);
}
for scalar in projection_expressions.iter() {
match scalar {
ProjectionColumn::Column(_) => {}
_ => {
crate::bail_parse_error!(
"Only column references are allowed in the SELECT clause when using GROUP BY"
);
}
}
}
}
if !aggregate_expressions.is_empty() {
operator = Operator::Aggregate {
source: Box::new(operator),
aggregates: aggregate_expressions,
group_by: group_by.map(|g| g.exprs), // TODO: support HAVING
id: operator_id_counter.get_next_id(),
step: 0,
}
}
if !projection_expressions.is_empty() {
operator = Operator::Projection {
source: Box::new(operator),
expressions: projection_expressions,
id: operator_id_counter.get_next_id(),
step: 0,
};
}
}
if let Some(group_by) = group_by.as_mut() {
for expr in group_by.exprs.iter_mut() {
bind_column_references(expr, &plan.referenced_tables)?;
}
if aggregate_expressions.is_empty() {
crate::bail_parse_error!(
"GROUP BY clause without aggregate functions is not allowed"
);
}
}
plan.group_by = group_by.map(|g| g.exprs);
plan.aggregates = if aggregate_expressions.is_empty() {
None
} else {
Some(aggregate_expressions)
};
// Parse the ORDER BY clause
if let Some(order_by) = select.order_by {
@@ -218,7 +418,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
for o in order_by {
// if the ORDER BY expression is a number, interpret it as an 1-indexed column number
// otherwise, interpret it normally as an expression
let expr = if let ast::Expr::Literal(ast::Literal::Numeric(num)) = &o.expr {
let mut expr = if let ast::Expr::Literal(ast::Literal::Numeric(num)) = &o.expr {
let column_number = num.parse::<usize>()?;
if column_number == 0 {
crate::bail_parse_error!("invalid column index: {}", column_number);
@@ -235,6 +435,11 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
o.expr
};
bind_column_references(&mut expr, &plan.referenced_tables)?;
if let Some(aggs) = &mut plan.aggregates {
resolve_aggregates(&expr, aggs);
}
key.push((
expr,
o.order.map_or(Direction::Ascending, |o| match o {
@@ -243,40 +448,22 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
}),
));
}
operator = Operator::Order {
source: Box::new(operator),
key,
id: operator_id_counter.get_next_id(),
step: 0,
};
plan.order_by = Some(key);
}
// Parse the LIMIT clause
if let Some(limit) = &select.limit {
operator = match &limit.expr {
plan.limit = match &limit.expr {
ast::Expr::Literal(ast::Literal::Numeric(n)) => {
let l = n.parse()?;
if l == 0 {
Operator::Nothing
} else {
Operator::Limit {
source: Box::new(operator),
limit: l,
id: operator_id_counter.get_next_id(),
step: 0,
}
}
Some(l)
}
_ => todo!(),
}
}
// Return the unoptimized query plan
Ok(Plan {
root_operator: operator,
referenced_tables,
available_indexes: schema.indexes.clone().into_values().flatten().collect(),
})
Ok(plan)
}
_ => todo!(),
}
@@ -287,9 +474,9 @@ fn parse_from(
schema: &Schema,
from: Option<FromClause>,
operator_id_counter: &mut OperatorIdCounter,
) -> Result<(Operator, Vec<BTreeTableReference>)> {
) -> Result<(SourceOperator, Vec<BTreeTableReference>)> {
if from.as_ref().and_then(|f| f.select.as_ref()).is_none() {
return Ok((Operator::Nothing, vec![]));
return Ok((SourceOperator::Nothing, vec![]));
}
let from = from.unwrap();
@@ -309,32 +496,33 @@ fn parse_from(
BTreeTableReference {
table: table.clone(),
table_identifier: alias.unwrap_or(qualified_name.name.0),
table_index: 0,
}
}
_ => todo!(),
};
let mut operator = Operator::Scan {
let mut operator = SourceOperator::Scan {
table_reference: first_table.clone(),
predicates: None,
id: operator_id_counter.get_next_id(),
step: 0,
iter_dir: None,
};
let mut tables = vec![first_table];
let mut table_index = 1;
for join in from.joins.unwrap_or_default().into_iter() {
let (right, outer, predicates) =
parse_join(schema, join, operator_id_counter, &mut tables)?;
operator = Operator::Join {
parse_join(schema, join, operator_id_counter, &mut tables, table_index)?;
operator = SourceOperator::Join {
left: Box::new(operator),
right: Box::new(right),
predicates,
outer,
id: operator_id_counter.get_next_id(),
step: 0,
}
};
table_index += 1;
}
Ok((operator, tables))
@@ -345,7 +533,8 @@ fn parse_join(
join: ast::JoinedSelectTable,
operator_id_counter: &mut OperatorIdCounter,
tables: &mut Vec<BTreeTableReference>,
) -> Result<(Operator, bool, Option<Vec<ast::Expr>>)> {
table_index: usize,
) -> Result<(SourceOperator, bool, Option<Vec<ast::Expr>>)> {
let ast::JoinedSelectTable {
operator,
table,
@@ -366,6 +555,7 @@ fn parse_join(
BTreeTableReference {
table: table.clone(),
table_identifier: alias.unwrap_or(qualified_name.name.0),
table_index,
}
}
_ => todo!(),
@@ -384,21 +574,26 @@ fn parse_join(
_ => false,
};
let predicates = constraint.map(|c| match c {
ast::JoinConstraint::On(expr) => {
let mut predicates = vec![];
break_predicate_at_and_boundaries(expr, &mut predicates);
predicates
let mut predicates = None;
if let Some(constraint) = constraint {
match constraint {
ast::JoinConstraint::On(expr) => {
let mut preds = vec![];
break_predicate_at_and_boundaries(expr, &mut preds);
for predicate in preds.iter_mut() {
bind_column_references(predicate, tables)?;
}
predicates = Some(preds);
}
ast::JoinConstraint::Using(_) => todo!("USING joins not supported yet"),
}
ast::JoinConstraint::Using(_) => todo!("USING joins not supported yet"),
});
}
Ok((
Operator::Scan {
SourceOperator::Scan {
table_reference: table.clone(),
predicates: None,
id: operator_id_counter.get_next_id(),
step: 0,
iter_dir: None,
},
outer,

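The join-constraint handling above splits an `ON` expression into its conjuncts before binding column references, so each predicate can later be pushed down and evaluated independently. A minimal sketch of such a splitter, using toy types (the real `break_predicate_at_and_boundaries` walks the parser's `ast::Expr`):

```rust
// Toy expression type; the real code operates on ast::Expr from the SQL parser.
#[derive(Debug, PartialEq, Clone)]
enum Expr {
    And(Box<Expr>, Box<Expr>),
    Pred(&'static str),
}

// Recursively flatten a predicate tree at AND boundaries, so that
// `a AND b AND c` yields three independently evaluable predicates.
fn break_at_and_boundaries(expr: Expr, out: &mut Vec<Expr>) {
    match expr {
        Expr::And(lhs, rhs) => {
            break_at_and_boundaries(*lhs, out);
            break_at_and_boundaries(*rhs, out);
        }
        leaf => out.push(leaf),
    }
}

fn main() {
    let expr = Expr::And(
        Box::new(Expr::Pred("t.a = u.a")),
        Box::new(Expr::And(
            Box::new(Expr::Pred("t.b > 1")),
            Box::new(Expr::Pred("u.c IS NOT NULL")),
        )),
    );
    let mut out = vec![];
    break_at_and_boundaries(expr, &mut out);
    assert_eq!(out.len(), 3);
}
```

Note that only top-level ANDs are split; an `OR` (or any other node) is kept whole as a single leaf predicate.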

@@ -17,11 +17,6 @@ pub fn translate_select(
connection: Weak<Connection>,
) -> Result<Program> {
let select_plan = prepare_select_plan(schema, select)?;
let (optimized_plan, expr_result_cache) = optimize_plan(select_plan)?;
emit_program(
database_header,
optimized_plan,
expr_result_cache,
connection,
)
let optimized_plan = optimize_plan(select_plan)?;
emit_program(database_header, optimized_plan, connection)
}


@@ -343,14 +343,7 @@ impl ProgramBuilder {
}
// translate table to cursor id
pub fn resolve_cursor_id(
&self,
table_identifier: &str,
cursor_hint: Option<CursorID>,
) -> CursorID {
if let Some(cursor_hint) = cursor_hint {
return cursor_hint;
}
pub fn resolve_cursor_id(&self, table_identifier: &str) -> CursorID {
self.cursor_ref
.iter()
.position(|(t_ident, _)| {
@@ -361,10 +354,6 @@ impl ProgramBuilder {
.unwrap()
}
pub fn resolve_cursor_to_table(&self, cursor_id: CursorID) -> Option<Table> {
self.cursor_ref[cursor_id].1.clone()
}
pub fn resolve_deferred_labels(&mut self) {
for i in 0..self.deferred_label_resolutions.len() {
let (label, insn_reference) = self.deferred_label_resolutions[i];

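With the `cursor_hint` parameter gone, `resolve_cursor_id` is a plain positional lookup. A minimal sketch of that lookup, assuming (as in `ProgramBuilder`) a `cursor_ref` list of (table identifier, optional table) pairs whose vector index doubles as the cursor id:

```rust
// Hypothetical simplified types: the real cursor_ref stores Option<Table>.
// Resolve a table identifier to its cursor id by position in cursor_ref;
// panics if no cursor was opened for the table, mirroring the .unwrap()
// in the diff above.
fn resolve_cursor_id(cursor_ref: &[(String, Option<&'static str>)], table_identifier: &str) -> usize {
    cursor_ref
        .iter()
        .position(|(t_ident, _)| t_ident == table_identifier)
        .expect("no cursor opened for table")
}

fn main() {
    let cursor_ref = vec![
        ("users".to_string(), Some("users")),
        ("products".to_string(), Some("products")),
    ];
    // The second opened cursor gets id 1.
    assert_eq!(resolve_cursor_id(&cursor_ref, "products"), 1);
}
```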

@@ -15,6 +15,18 @@ do_execsql_test add-int-float {
SELECT 10 + 0.1
} {10.1}
do_execsql_test add-agg-int-agg-int {
SELECT sum(1) + sum(2)
} {3}
do_execsql_test add-agg-int-agg-float {
SELECT sum(1) + sum(2.5)
} {3.5}
do_execsql_test add-agg-float-agg-int {
SELECT sum(1.5) + sum(2)
} {3.5}
do_execsql_test subtract-int {
SELECT 10 - 1
} {9}
@@ -27,6 +39,18 @@ do_execsql_test subtract-int-float {
SELECT 10 - 0.1
} {9.9}
do_execsql_test subtract-agg-int-agg-int {
SELECT sum(3) - sum(1)
} {2}
do_execsql_test subtract-agg-int-agg-float {
SELECT sum(3) - sum(1.5)
} {1.5}
do_execsql_test subtract-agg-float-agg-int {
SELECT sum(3.5) - sum(1)
} {2.5}
do_execsql_test multiply-int {
SELECT 10 * 2
} {20}
@@ -43,6 +67,18 @@ do_execsql_test multiply-float-int {
SELECT 1.45 * 10
} {14.5}
do_execsql_test multiply-agg-int-agg-int {
SELECT sum(2) * sum(3)
} {6}
do_execsql_test multiply-agg-int-agg-float {
SELECT sum(2) * sum(3.5)
} {7.0}
do_execsql_test multiply-agg-float-agg-int {
SELECT sum(2.5) * sum(3)
} {7.5}
do_execsql_test divide-int {
SELECT 10 / 2
} {5}
@@ -79,6 +115,17 @@ do_execsql_test divide-null {
SELECT null / null
} {}
do_execsql_test divide-agg-int-agg-int {
SELECT sum(4) / sum(2)
} {2}
do_execsql_test divide-agg-int-agg-float {
SELECT sum(4) / sum(2.0)
} {2.0}
do_execsql_test divide-agg-float-agg-int {
SELECT sum(4.0) / sum(2)
} {2.0}
do_execsql_test add-agg-int {

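The tests above exercise arithmetic over aggregates (e.g. `SELECT sum(1) + sum(2)`), which relies on the planner detecting aggregate subexpressions inside a result column. A toy sketch of that walk, with hypothetical types (the real `resolve_aggregates` operates on `ast::Expr` and fills the plan's aggregate list):

```rust
// Toy expression type; the real code walks the parser's ast::Expr.
#[derive(Debug, PartialEq)]
enum Expr {
    Agg(&'static str),                  // e.g. sum(...)
    Binary(Box<Expr>, char, Box<Expr>), // e.g. lhs + rhs
    Lit(f64),
}

// Collect aggregate subexpressions into `aggs`. A result column
// "contains aggregates" iff this walk pushed at least one entry,
// which is how expressions like sum(1) + sum(2) are detected.
fn resolve_aggregates(expr: &Expr, aggs: &mut Vec<&'static str>) {
    match expr {
        Expr::Agg(name) => aggs.push(name),
        Expr::Binary(lhs, _, rhs) => {
            resolve_aggregates(lhs, aggs);
            resolve_aggregates(rhs, aggs);
        }
        Expr::Lit(_) => {}
    }
}

fn main() {
    // sum(1) + sum(2): two aggregates found, so the column contains aggregates.
    let expr = Expr::Binary(Box::new(Expr::Agg("sum")), '+', Box::new(Expr::Agg("sum")));
    let mut aggs = vec![];
    resolve_aggregates(&expr, &mut aggs);
    assert_eq!(aggs.len(), 2);
}
```

This matches how `prepare_select_plan` uses it in the diff above: the aggregate list's length is compared before and after the walk to set `contains_aggregates` on each result column.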

@@ -115,4 +115,25 @@ Dennis|Ward|1
Whitney|Walker|1
Robert|Villanueva|1
Cynthia|Thomas|1
Brandon|Tate|1}
do_execsql_test order-by-case-insensitive-aggregate {
select u.first_name, sum(u.age) from users u group by u.first_name order by SUM(u.aGe) desc limit 10;
} {Michael|11204
David|8758
Robert|8109
Jennifer|7700
John|7299
Christopher|6397
James|5921
Joseph|5711
Brian|5059
William|5047}
do_execsql_test order-by-agg-not-mentioned-in-select {
select u.first_name, length(group_concat(u.last_name)) from users u group by u.first_name order by max(u.email) desc limit 5;
} {Louis|65
Carolyn|118
Katelyn|40
Erik|88
Collin|15}


@@ -637,6 +637,7 @@ impl ToTokens for Expr {
Ok(())
}
Self::Id(id) => id.to_tokens(s),
Self::Column { .. } => Ok(()),
Self::InList { lhs, not, rhs } => {
lhs.to_tokens(s)?;
if *not {


@@ -327,6 +327,17 @@ pub enum Expr {
},
/// Identifier
Id(Id),
/// Column
Column {
/// the x in `x.y.z`. index of the db in catalog.
database: Option<usize>,
/// the y in `x.y.z`. index of the table in catalog.
table: usize,
/// the z in `x.y.z`. index of the column in the table.
column: usize,
/// is the column a rowid alias
is_rowid_alias: bool,
},
/// `IN`
InList {
/// expression