Add basic CTE support

This commit is contained in:
Jussi Saurio
2025-02-08 14:50:05 +02:00
parent 338c27dad6
commit 9e70e8fe02
3 changed files with 159 additions and 46 deletions

View File

@@ -345,6 +345,7 @@ impl Connection {
&self.schema.borrow(),
*select,
&self.db.syms.borrow(),
None,
)?;
optimize_plan(&mut plan, &self.schema.borrow())?;
println!("{}", plan);

View File

@@ -1,7 +1,7 @@
use super::{
plan::{
Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference,
WhereTerm,
Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectPlan, SelectQueryType,
TableReference, WhereTerm,
},
select::prepare_select_plan,
SymbolTable,
@@ -13,7 +13,9 @@ use crate::{
vdbe::BranchOffset,
Result, VirtualTable,
};
use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit, UnaryOperator};
use sqlite3_parser::ast::{
self, Expr, FromClause, JoinType, Limit, Materialized, UnaryOperator, With,
};
pub const ROWID: &str = "rowid";
@@ -278,46 +280,92 @@ pub fn bind_column_references(
}
}
fn parse_from_clause_table(
fn parse_from_clause_table<'a>(
schema: &Schema,
table: ast::SelectTable,
cur_table_index: usize,
scope: &mut Scope<'a>,
syms: &SymbolTable,
) -> Result<TableReference> {
) -> Result<()> {
match table {
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
let normalized_qualified_name = normalize_ident(qualified_name.name.0.as_str());
let Some(table) = schema.get_table(&normalized_qualified_name) else {
crate::bail_parse_error!("Table {} not found", normalized_qualified_name);
// Check if the FROM clause table is referring to a CTE in the current scope.
if let Some(cte) = scope
.ctes
.iter()
.find(|cte| cte.name == normalized_qualified_name)
{
// CTE can be rewritten as a subquery.
// TODO: find a way not to clone the CTE plan here.
let cte_table =
TableReference::new_subquery(cte.name.clone(), cte.plan.clone(), None);
scope.tables.push(cte_table);
return Ok(());
};
let alias = maybe_alias
.map(|a| match a {
ast::As::As(id) => id,
ast::As::Elided(id) => id,
})
.map(|a| a.0);
Ok(TableReference {
op: Operation::Scan { iter_dir: None },
table: Table::BTree(table.clone()),
identifier: alias.unwrap_or(normalized_qualified_name),
join_info: None,
})
// Check if our top level schema has this table.
if let Some(table) = schema.get_table(&normalized_qualified_name) {
let alias = maybe_alias
.map(|a| match a {
ast::As::As(id) => id,
ast::As::Elided(id) => id,
})
.map(|a| a.0);
scope.tables.push(TableReference {
op: Operation::Scan { iter_dir: None },
table: Table::BTree(table.clone()),
identifier: alias.unwrap_or(normalized_qualified_name),
join_info: None,
});
return Ok(());
};
// Check if the outer query scope has this table.
if let Some(outer_scope) = scope.parent {
if let Some(table_ref_idx) = outer_scope
.tables
.iter()
.position(|t| t.identifier == normalized_qualified_name)
{
// TODO: avoid cloning the table reference here.
scope.tables.push(outer_scope.tables[table_ref_idx].clone());
return Ok(());
}
if let Some(cte) = outer_scope
.ctes
.iter()
.find(|cte| cte.name == normalized_qualified_name)
{
// TODO: avoid cloning the CTE plan here.
let cte_table =
TableReference::new_subquery(cte.name.clone(), cte.plan.clone(), None);
scope.tables.push(cte_table);
return Ok(());
}
}
crate::bail_parse_error!("Table {} not found", normalized_qualified_name);
}
ast::SelectTable::Select(subselect, maybe_alias) => {
let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect, syms)? else {
let Plan::Select(mut subplan) =
prepare_select_plan(schema, *subselect, syms, Some(scope))?
else {
unreachable!();
};
subplan.query_type = SelectQueryType::Subquery {
yield_reg: usize::MAX, // will be set later in bytecode emission
coroutine_implementation_start: BranchOffset::Placeholder, // will be set later in bytecode emission
};
let cur_table_index = scope.tables.len();
let identifier = maybe_alias
.map(|a| match a {
ast::As::As(id) => id.0.clone(),
ast::As::Elided(id) => id.0.clone(),
})
.unwrap_or(format!("subquery_{}", cur_table_index));
Ok(TableReference::new_subquery(identifier, subplan, None))
scope
.tables
.push(TableReference::new_subquery(identifier, subplan, None));
Ok(())
}
ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => {
let normalized_name = &normalize_ident(qualified_name.name.0.as_str());
@@ -332,7 +380,7 @@ fn parse_from_clause_table(
})
.unwrap_or(normalized_name.to_string());
Ok(TableReference {
scope.tables.push(TableReference {
op: Operation::Scan { iter_dir: None },
join_info: None,
table: Table::Virtual(
@@ -346,7 +394,8 @@ fn parse_from_clause_table(
)
.into(),
identifier: alias.clone(),
})
});
Ok(())
}
_ => todo!(),
}
@@ -391,25 +440,85 @@ pub struct Cte {
/// Currently we only support SELECT queries in CTEs.
plan: SelectPlan,
}
pub fn parse_from<'a>(
schema: &Schema,
mut from: Option<FromClause>,
syms: &SymbolTable,
with: Option<With>,
out_where_clause: &mut Vec<WhereTerm>,
outer_scope: Option<&'a Scope<'a>>,
) -> Result<Vec<TableReference>> {
if from.as_ref().and_then(|f| f.select.as_ref()).is_none() {
return Ok(vec![]);
}
let mut scope = Scope {
tables: vec![],
ctes: vec![],
parent: outer_scope,
};
if let Some(with) = with {
if with.recursive {
crate::bail_parse_error!("Recursive CTEs are not yet supported");
}
for cte in with.ctes {
if cte.materialized == Materialized::Yes {
crate::bail_parse_error!("Materialized CTEs are not yet supported");
}
if cte.columns.is_some() {
crate::bail_parse_error!("CTE columns are not yet supported");
}
// Check if normalized name conflicts with catalog tables or other CTEs
// TODO: sqlite actually allows overriding a catalog table with a CTE.
// We should carry over the 'Scope' struct to all of our identifier resolution.
let cte_name_normalized = normalize_ident(&cte.tbl_name.0);
if schema.get_table(&cte_name_normalized).is_some() {
crate::bail_parse_error!(
"CTE name {} conflicts with catalog table name",
cte.tbl_name.0
);
}
if scope
.tables
.iter()
.any(|t| t.identifier == cte_name_normalized)
{
crate::bail_parse_error!("CTE name {} conflicts with table name", cte.tbl_name.0);
}
if scope.ctes.iter().any(|c| c.name == cte_name_normalized) {
crate::bail_parse_error!("duplicate WITH table name {}", cte.tbl_name.0);
}
// CTE can refer to other CTEs that came before it, plus any schema tables or tables in the outer scope.
let cte_plan = prepare_select_plan(schema, *cte.select, syms, Some(&scope))?;
let Plan::Select(mut cte_plan) = cte_plan else {
crate::bail_parse_error!("Only SELECT queries are currently supported in CTEs");
};
// CTE can be rewritten as a subquery.
cte_plan.query_type = SelectQueryType::Subquery {
yield_reg: usize::MAX, // will be set later in bytecode emission
coroutine_implementation_start: BranchOffset::Placeholder, // will be set later in bytecode emission
};
scope.ctes.push(Cte {
name: cte_name_normalized,
plan: cte_plan,
});
}
}
let mut from_owned = std::mem::take(&mut from).unwrap();
let select_owned = *std::mem::take(&mut from_owned.select).unwrap();
let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default();
let mut tables = vec![parse_from_clause_table(schema, select_owned, 0, syms)?];
parse_from_clause_table(schema, select_owned, &mut scope, syms)?;
for join in joins_owned.into_iter() {
parse_join(schema, join, syms, &mut tables, out_where_clause)?;
parse_join(schema, join, syms, &mut scope, out_where_clause)?;
}
Ok(tables)
Ok(scope.tables)
}
pub fn parse_where(
@@ -489,11 +598,11 @@ fn get_rightmost_table_referenced_in_expr<'a>(predicate: &'a ast::Expr) -> Resul
Ok(max_table_idx)
}
fn parse_join(
fn parse_join<'a>(
schema: &Schema,
join: ast::JoinedSelectTable,
syms: &SymbolTable,
tables: &mut Vec<TableReference>,
scope: &mut Scope<'a>,
out_where_clause: &mut Vec<WhereTerm>,
) -> Result<()> {
let ast::JoinedSelectTable {
@@ -502,9 +611,7 @@ fn parse_join(
constraint,
} = join;
let cur_table_index = tables.len();
let table = parse_from_clause_table(schema, table, cur_table_index, syms)?;
tables.push(table);
parse_from_clause_table(schema, table, scope, syms)?;
let (outer, natural) = match join_operator {
ast::JoinOperator::TypedJoin(Some(join_type)) => {
@@ -522,15 +629,15 @@ fn parse_join(
}
let constraint = if natural {
assert!(tables.len() >= 2);
let rightmost_table = tables.last().unwrap();
assert!(scope.tables.len() >= 2);
let rightmost_table = scope.tables.last().unwrap();
// NATURAL JOIN is first transformed into a USING join with the common columns
let right_cols = rightmost_table.columns();
let mut distinct_names: Option<ast::DistinctNames> = None;
// TODO: O(n^2) maybe not great for large tables or big multiway joins
for right_col in right_cols.iter() {
let mut found_match = false;
for left_table in tables.iter().take(tables.len() - 1) {
for left_table in scope.tables.iter().take(scope.tables.len() - 1) {
for left_col in left_table.columns().iter() {
if left_col.name == right_col.name {
if let Some(distinct_names) = distinct_names.as_mut() {
@@ -568,10 +675,10 @@ fn parse_join(
let mut preds = vec![];
break_predicate_at_and_boundaries(expr, &mut preds);
for predicate in preds.iter_mut() {
bind_column_references(predicate, tables, None)?;
bind_column_references(predicate, &scope.tables, None)?;
}
for pred in preds {
let cur_table_idx = tables.len() - 1;
let cur_table_idx = scope.tables.len() - 1;
let eval_at_loop = if outer {
cur_table_idx
} else {
@@ -588,10 +695,10 @@ fn parse_join(
// USING join is replaced with a list of equality predicates
for distinct_name in distinct_names.iter() {
let name_normalized = normalize_ident(distinct_name.0.as_str());
let cur_table_idx = tables.len() - 1;
let left_tables = &tables[..cur_table_idx];
let cur_table_idx = scope.tables.len() - 1;
let left_tables = &scope.tables[..cur_table_idx];
assert!(!left_tables.is_empty());
let right_table = tables.last().unwrap();
let right_table = scope.tables.last().unwrap();
let mut left_col = None;
for (left_table_idx, left_table) in left_tables.iter().enumerate() {
left_col = left_table
@@ -658,9 +765,9 @@ fn parse_join(
}
}
assert!(tables.len() >= 2);
let last_idx = tables.len() - 1;
let rightmost_table = tables.get_mut(last_idx).unwrap();
assert!(scope.tables.len() >= 2);
let last_idx = scope.tables.len() - 1;
let rightmost_table = scope.tables.get_mut(last_idx).unwrap();
rightmost_table.join_info = Some(JoinInfo { outer, using });
Ok(())

View File

@@ -1,5 +1,6 @@
use super::emitter::emit_program;
use super::plan::{select_star, Operation, Search, SelectQueryType};
use super::planner::Scope;
use crate::function::{AggFunc, ExtFunc, Func};
use crate::translate::optimizer::optimize_plan;
use crate::translate::plan::{Aggregate, Direction, GroupBy, Plan, ResultSetColumn, SelectPlan};
@@ -20,7 +21,7 @@ pub fn translate_select(
select: ast::Select,
syms: &SymbolTable,
) -> Result<ProgramBuilder> {
let mut select_plan = prepare_select_plan(schema, select, syms)?;
let mut select_plan = prepare_select_plan(schema, select, syms, None)?;
optimize_plan(&mut select_plan, schema)?;
let Plan::Select(ref select) = select_plan else {
panic!("select_plan is not a SelectPlan");
@@ -36,10 +37,11 @@ pub fn translate_select(
Ok(program)
}
pub fn prepare_select_plan(
pub fn prepare_select_plan<'a>(
schema: &Schema,
select: ast::Select,
syms: &SymbolTable,
outer_scope: Option<&'a Scope<'a>>,
) -> Result<Plan> {
match *select.body.select {
ast::OneSelect::Select {
@@ -56,8 +58,11 @@ pub fn prepare_select_plan(
let mut where_predicates = vec![];
let with = select.with;
// Parse the FROM clause into a vec of TableReferences. Fold all the join conditions expressions into the WHERE clause.
let table_references = parse_from(schema, from, syms, &mut where_predicates)?;
let table_references =
parse_from(schema, from, syms, with, &mut where_predicates, outer_scope)?;
// Preallocate space for the result columns
let result_columns = Vec::with_capacity(