mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-23 00:45:37 +01:00
Add basic CTE support
This commit is contained in:
@@ -345,6 +345,7 @@ impl Connection {
|
||||
&self.schema.borrow(),
|
||||
*select,
|
||||
&self.db.syms.borrow(),
|
||||
None,
|
||||
)?;
|
||||
optimize_plan(&mut plan, &self.schema.borrow())?;
|
||||
println!("{}", plan);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
plan::{
|
||||
Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectQueryType, TableReference,
|
||||
WhereTerm,
|
||||
Aggregate, JoinInfo, Operation, Plan, ResultSetColumn, SelectPlan, SelectQueryType,
|
||||
TableReference, WhereTerm,
|
||||
},
|
||||
select::prepare_select_plan,
|
||||
SymbolTable,
|
||||
@@ -13,7 +13,9 @@ use crate::{
|
||||
vdbe::BranchOffset,
|
||||
Result, VirtualTable,
|
||||
};
|
||||
use sqlite3_parser::ast::{self, Expr, FromClause, JoinType, Limit, UnaryOperator};
|
||||
use sqlite3_parser::ast::{
|
||||
self, Expr, FromClause, JoinType, Limit, Materialized, UnaryOperator, With,
|
||||
};
|
||||
|
||||
pub const ROWID: &str = "rowid";
|
||||
|
||||
@@ -278,46 +280,92 @@ pub fn bind_column_references(
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_from_clause_table(
|
||||
fn parse_from_clause_table<'a>(
|
||||
schema: &Schema,
|
||||
table: ast::SelectTable,
|
||||
cur_table_index: usize,
|
||||
scope: &mut Scope<'a>,
|
||||
syms: &SymbolTable,
|
||||
) -> Result<TableReference> {
|
||||
) -> Result<()> {
|
||||
match table {
|
||||
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
|
||||
let normalized_qualified_name = normalize_ident(qualified_name.name.0.as_str());
|
||||
let Some(table) = schema.get_table(&normalized_qualified_name) else {
|
||||
crate::bail_parse_error!("Table {} not found", normalized_qualified_name);
|
||||
// Check if the FROM clause table is referring to a CTE in the current scope.
|
||||
if let Some(cte) = scope
|
||||
.ctes
|
||||
.iter()
|
||||
.find(|cte| cte.name == normalized_qualified_name)
|
||||
{
|
||||
// CTE can be rewritten as a subquery.
|
||||
// TODO: find a way not to clone the CTE plan here.
|
||||
let cte_table =
|
||||
TableReference::new_subquery(cte.name.clone(), cte.plan.clone(), None);
|
||||
scope.tables.push(cte_table);
|
||||
return Ok(());
|
||||
};
|
||||
let alias = maybe_alias
|
||||
.map(|a| match a {
|
||||
ast::As::As(id) => id,
|
||||
ast::As::Elided(id) => id,
|
||||
})
|
||||
.map(|a| a.0);
|
||||
Ok(TableReference {
|
||||
op: Operation::Scan { iter_dir: None },
|
||||
table: Table::BTree(table.clone()),
|
||||
identifier: alias.unwrap_or(normalized_qualified_name),
|
||||
join_info: None,
|
||||
})
|
||||
// Check if our top level schema has this table.
|
||||
if let Some(table) = schema.get_table(&normalized_qualified_name) {
|
||||
let alias = maybe_alias
|
||||
.map(|a| match a {
|
||||
ast::As::As(id) => id,
|
||||
ast::As::Elided(id) => id,
|
||||
})
|
||||
.map(|a| a.0);
|
||||
scope.tables.push(TableReference {
|
||||
op: Operation::Scan { iter_dir: None },
|
||||
table: Table::BTree(table.clone()),
|
||||
identifier: alias.unwrap_or(normalized_qualified_name),
|
||||
join_info: None,
|
||||
});
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Check if the outer query scope has this table.
|
||||
if let Some(outer_scope) = scope.parent {
|
||||
if let Some(table_ref_idx) = outer_scope
|
||||
.tables
|
||||
.iter()
|
||||
.position(|t| t.identifier == normalized_qualified_name)
|
||||
{
|
||||
// TODO: avoid cloning the table reference here.
|
||||
scope.tables.push(outer_scope.tables[table_ref_idx].clone());
|
||||
return Ok(());
|
||||
}
|
||||
if let Some(cte) = outer_scope
|
||||
.ctes
|
||||
.iter()
|
||||
.find(|cte| cte.name == normalized_qualified_name)
|
||||
{
|
||||
// TODO: avoid cloning the CTE plan here.
|
||||
let cte_table =
|
||||
TableReference::new_subquery(cte.name.clone(), cte.plan.clone(), None);
|
||||
scope.tables.push(cte_table);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
crate::bail_parse_error!("Table {} not found", normalized_qualified_name);
|
||||
}
|
||||
ast::SelectTable::Select(subselect, maybe_alias) => {
|
||||
let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect, syms)? else {
|
||||
let Plan::Select(mut subplan) =
|
||||
prepare_select_plan(schema, *subselect, syms, Some(scope))?
|
||||
else {
|
||||
unreachable!();
|
||||
};
|
||||
subplan.query_type = SelectQueryType::Subquery {
|
||||
yield_reg: usize::MAX, // will be set later in bytecode emission
|
||||
coroutine_implementation_start: BranchOffset::Placeholder, // will be set later in bytecode emission
|
||||
};
|
||||
let cur_table_index = scope.tables.len();
|
||||
let identifier = maybe_alias
|
||||
.map(|a| match a {
|
||||
ast::As::As(id) => id.0.clone(),
|
||||
ast::As::Elided(id) => id.0.clone(),
|
||||
})
|
||||
.unwrap_or(format!("subquery_{}", cur_table_index));
|
||||
Ok(TableReference::new_subquery(identifier, subplan, None))
|
||||
scope
|
||||
.tables
|
||||
.push(TableReference::new_subquery(identifier, subplan, None));
|
||||
Ok(())
|
||||
}
|
||||
ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => {
|
||||
let normalized_name = &normalize_ident(qualified_name.name.0.as_str());
|
||||
@@ -332,7 +380,7 @@ fn parse_from_clause_table(
|
||||
})
|
||||
.unwrap_or(normalized_name.to_string());
|
||||
|
||||
Ok(TableReference {
|
||||
scope.tables.push(TableReference {
|
||||
op: Operation::Scan { iter_dir: None },
|
||||
join_info: None,
|
||||
table: Table::Virtual(
|
||||
@@ -346,7 +394,8 @@ fn parse_from_clause_table(
|
||||
)
|
||||
.into(),
|
||||
identifier: alias.clone(),
|
||||
})
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
@@ -391,25 +440,85 @@ pub struct Cte {
|
||||
/// Currently we only support SELECT queries in CTEs.
|
||||
plan: SelectPlan,
|
||||
}
|
||||
|
||||
pub fn parse_from<'a>(
|
||||
schema: &Schema,
|
||||
mut from: Option<FromClause>,
|
||||
syms: &SymbolTable,
|
||||
with: Option<With>,
|
||||
out_where_clause: &mut Vec<WhereTerm>,
|
||||
outer_scope: Option<&'a Scope<'a>>,
|
||||
) -> Result<Vec<TableReference>> {
|
||||
if from.as_ref().and_then(|f| f.select.as_ref()).is_none() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let mut scope = Scope {
|
||||
tables: vec![],
|
||||
ctes: vec![],
|
||||
parent: outer_scope,
|
||||
};
|
||||
|
||||
if let Some(with) = with {
|
||||
if with.recursive {
|
||||
crate::bail_parse_error!("Recursive CTEs are not yet supported");
|
||||
}
|
||||
for cte in with.ctes {
|
||||
if cte.materialized == Materialized::Yes {
|
||||
crate::bail_parse_error!("Materialized CTEs are not yet supported");
|
||||
}
|
||||
if cte.columns.is_some() {
|
||||
crate::bail_parse_error!("CTE columns are not yet supported");
|
||||
}
|
||||
|
||||
// Check if normalized name conflicts with catalog tables or other CTEs
|
||||
// TODO: sqlite actually allows overriding a catalog table with a CTE.
|
||||
// We should carry over the 'Scope' struct to all of our identifier resolution.
|
||||
let cte_name_normalized = normalize_ident(&cte.tbl_name.0);
|
||||
if schema.get_table(&cte_name_normalized).is_some() {
|
||||
crate::bail_parse_error!(
|
||||
"CTE name {} conflicts with catalog table name",
|
||||
cte.tbl_name.0
|
||||
);
|
||||
}
|
||||
if scope
|
||||
.tables
|
||||
.iter()
|
||||
.any(|t| t.identifier == cte_name_normalized)
|
||||
{
|
||||
crate::bail_parse_error!("CTE name {} conflicts with table name", cte.tbl_name.0);
|
||||
}
|
||||
if scope.ctes.iter().any(|c| c.name == cte_name_normalized) {
|
||||
crate::bail_parse_error!("duplicate WITH table name {}", cte.tbl_name.0);
|
||||
}
|
||||
|
||||
// CTE can refer to other CTEs that came before it, plus any schema tables or tables in the outer scope.
|
||||
let cte_plan = prepare_select_plan(schema, *cte.select, syms, Some(&scope))?;
|
||||
let Plan::Select(mut cte_plan) = cte_plan else {
|
||||
crate::bail_parse_error!("Only SELECT queries are currently supported in CTEs");
|
||||
};
|
||||
// CTE can be rewritten as a subquery.
|
||||
cte_plan.query_type = SelectQueryType::Subquery {
|
||||
yield_reg: usize::MAX, // will be set later in bytecode emission
|
||||
coroutine_implementation_start: BranchOffset::Placeholder, // will be set later in bytecode emission
|
||||
};
|
||||
scope.ctes.push(Cte {
|
||||
name: cte_name_normalized,
|
||||
plan: cte_plan,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut from_owned = std::mem::take(&mut from).unwrap();
|
||||
let select_owned = *std::mem::take(&mut from_owned.select).unwrap();
|
||||
let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default();
|
||||
let mut tables = vec![parse_from_clause_table(schema, select_owned, 0, syms)?];
|
||||
parse_from_clause_table(schema, select_owned, &mut scope, syms)?;
|
||||
|
||||
for join in joins_owned.into_iter() {
|
||||
parse_join(schema, join, syms, &mut tables, out_where_clause)?;
|
||||
parse_join(schema, join, syms, &mut scope, out_where_clause)?;
|
||||
}
|
||||
|
||||
Ok(tables)
|
||||
Ok(scope.tables)
|
||||
}
|
||||
|
||||
pub fn parse_where(
|
||||
@@ -489,11 +598,11 @@ fn get_rightmost_table_referenced_in_expr<'a>(predicate: &'a ast::Expr) -> Resul
|
||||
Ok(max_table_idx)
|
||||
}
|
||||
|
||||
fn parse_join(
|
||||
fn parse_join<'a>(
|
||||
schema: &Schema,
|
||||
join: ast::JoinedSelectTable,
|
||||
syms: &SymbolTable,
|
||||
tables: &mut Vec<TableReference>,
|
||||
scope: &mut Scope<'a>,
|
||||
out_where_clause: &mut Vec<WhereTerm>,
|
||||
) -> Result<()> {
|
||||
let ast::JoinedSelectTable {
|
||||
@@ -502,9 +611,7 @@ fn parse_join(
|
||||
constraint,
|
||||
} = join;
|
||||
|
||||
let cur_table_index = tables.len();
|
||||
let table = parse_from_clause_table(schema, table, cur_table_index, syms)?;
|
||||
tables.push(table);
|
||||
parse_from_clause_table(schema, table, scope, syms)?;
|
||||
|
||||
let (outer, natural) = match join_operator {
|
||||
ast::JoinOperator::TypedJoin(Some(join_type)) => {
|
||||
@@ -522,15 +629,15 @@ fn parse_join(
|
||||
}
|
||||
|
||||
let constraint = if natural {
|
||||
assert!(tables.len() >= 2);
|
||||
let rightmost_table = tables.last().unwrap();
|
||||
assert!(scope.tables.len() >= 2);
|
||||
let rightmost_table = scope.tables.last().unwrap();
|
||||
// NATURAL JOIN is first transformed into a USING join with the common columns
|
||||
let right_cols = rightmost_table.columns();
|
||||
let mut distinct_names: Option<ast::DistinctNames> = None;
|
||||
// TODO: O(n^2) maybe not great for large tables or big multiway joins
|
||||
for right_col in right_cols.iter() {
|
||||
let mut found_match = false;
|
||||
for left_table in tables.iter().take(tables.len() - 1) {
|
||||
for left_table in scope.tables.iter().take(scope.tables.len() - 1) {
|
||||
for left_col in left_table.columns().iter() {
|
||||
if left_col.name == right_col.name {
|
||||
if let Some(distinct_names) = distinct_names.as_mut() {
|
||||
@@ -568,10 +675,10 @@ fn parse_join(
|
||||
let mut preds = vec![];
|
||||
break_predicate_at_and_boundaries(expr, &mut preds);
|
||||
for predicate in preds.iter_mut() {
|
||||
bind_column_references(predicate, tables, None)?;
|
||||
bind_column_references(predicate, &scope.tables, None)?;
|
||||
}
|
||||
for pred in preds {
|
||||
let cur_table_idx = tables.len() - 1;
|
||||
let cur_table_idx = scope.tables.len() - 1;
|
||||
let eval_at_loop = if outer {
|
||||
cur_table_idx
|
||||
} else {
|
||||
@@ -588,10 +695,10 @@ fn parse_join(
|
||||
// USING join is replaced with a list of equality predicates
|
||||
for distinct_name in distinct_names.iter() {
|
||||
let name_normalized = normalize_ident(distinct_name.0.as_str());
|
||||
let cur_table_idx = tables.len() - 1;
|
||||
let left_tables = &tables[..cur_table_idx];
|
||||
let cur_table_idx = scope.tables.len() - 1;
|
||||
let left_tables = &scope.tables[..cur_table_idx];
|
||||
assert!(!left_tables.is_empty());
|
||||
let right_table = tables.last().unwrap();
|
||||
let right_table = scope.tables.last().unwrap();
|
||||
let mut left_col = None;
|
||||
for (left_table_idx, left_table) in left_tables.iter().enumerate() {
|
||||
left_col = left_table
|
||||
@@ -658,9 +765,9 @@ fn parse_join(
|
||||
}
|
||||
}
|
||||
|
||||
assert!(tables.len() >= 2);
|
||||
let last_idx = tables.len() - 1;
|
||||
let rightmost_table = tables.get_mut(last_idx).unwrap();
|
||||
assert!(scope.tables.len() >= 2);
|
||||
let last_idx = scope.tables.len() - 1;
|
||||
let rightmost_table = scope.tables.get_mut(last_idx).unwrap();
|
||||
rightmost_table.join_info = Some(JoinInfo { outer, using });
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use super::emitter::emit_program;
|
||||
use super::plan::{select_star, Operation, Search, SelectQueryType};
|
||||
use super::planner::Scope;
|
||||
use crate::function::{AggFunc, ExtFunc, Func};
|
||||
use crate::translate::optimizer::optimize_plan;
|
||||
use crate::translate::plan::{Aggregate, Direction, GroupBy, Plan, ResultSetColumn, SelectPlan};
|
||||
@@ -20,7 +21,7 @@ pub fn translate_select(
|
||||
select: ast::Select,
|
||||
syms: &SymbolTable,
|
||||
) -> Result<ProgramBuilder> {
|
||||
let mut select_plan = prepare_select_plan(schema, select, syms)?;
|
||||
let mut select_plan = prepare_select_plan(schema, select, syms, None)?;
|
||||
optimize_plan(&mut select_plan, schema)?;
|
||||
let Plan::Select(ref select) = select_plan else {
|
||||
panic!("select_plan is not a SelectPlan");
|
||||
@@ -36,10 +37,11 @@ pub fn translate_select(
|
||||
Ok(program)
|
||||
}
|
||||
|
||||
pub fn prepare_select_plan(
|
||||
pub fn prepare_select_plan<'a>(
|
||||
schema: &Schema,
|
||||
select: ast::Select,
|
||||
syms: &SymbolTable,
|
||||
outer_scope: Option<&'a Scope<'a>>,
|
||||
) -> Result<Plan> {
|
||||
match *select.body.select {
|
||||
ast::OneSelect::Select {
|
||||
@@ -56,8 +58,11 @@ pub fn prepare_select_plan(
|
||||
|
||||
let mut where_predicates = vec![];
|
||||
|
||||
let with = select.with;
|
||||
|
||||
// Parse the FROM clause into a vec of TableReferences. Fold all the join conditions expressions into the WHERE clause.
|
||||
let table_references = parse_from(schema, from, syms, &mut where_predicates)?;
|
||||
let table_references =
|
||||
parse_from(schema, from, syms, with, &mut where_predicates, outer_scope)?;
|
||||
|
||||
// Preallocate space for the result columns
|
||||
let result_columns = Vec::with_capacity(
|
||||
|
||||
Reference in New Issue
Block a user