Merge 'Support JOIN USING and NATURAL JOIN' from Jussi Saurio

Closes #360
Closes #361

Closes #422
This commit is contained in:
Pekka Enberg
2024-12-11 09:17:51 +02:00
5 changed files with 264 additions and 42 deletions

View File

@@ -61,6 +61,8 @@ This document describes the SQLite compatibility status of Limbo:
| SELECT ... CROSS JOIN | Partial | |
| SELECT ... INNER JOIN | Partial | |
| SELECT ... OUTER JOIN | Partial | |
| SELECT ... JOIN USING | Yes | |
| SELECT ... NATURAL JOIN | Yes | |
| UPDATE | No | |
| UPSERT | No | |
| VACUUM | No | |

View File

@@ -260,16 +260,24 @@ fn eliminate_constants(
/**
Recursively pushes predicates down the tree, as far as possible.
Where a predicate is pushed determines at which loop level it will be evaluated.
For example, in SELECT * FROM t1 JOIN t2 JOIN t3 WHERE t1.a = t2.a AND t2.b = t3.b AND t1.c = 1
the predicate t1.c = 1 can be pushed to t1 and will be evaluated in the first (outermost) loop,
the predicate t1.a = t2.a can be pushed to t2 and will be evaluated in the second loop
while t2.b = t3.b will be evaluated in the third loop.
*/
fn push_predicates(
operator: &mut SourceOperator,
where_clause: &mut Option<Vec<ast::Expr>>,
referenced_tables: &Vec<BTreeTableReference>,
) -> Result<()> {
// First try to push down any predicates from the WHERE clause
if let Some(predicates) = where_clause {
let mut i = 0;
while i < predicates.len() {
// Take ownership of predicate to try pushing it down
let predicate = predicates[i].take_ownership();
// If predicate was successfully pushed (None returned), remove it from WHERE
let Some(predicate) = push_predicate(operator, predicate, referenced_tables)? else {
predicates.remove(i);
continue;
@@ -277,10 +285,12 @@ fn push_predicates(
predicates[i] = predicate;
i += 1;
}
// Clean up empty WHERE clause
if predicates.is_empty() {
*where_clause = None;
}
}
match operator {
SourceOperator::Join {
left,
@@ -289,6 +299,7 @@ fn push_predicates(
outer,
..
} => {
// Recursively push predicates down both sides of join
push_predicates(left, where_clause, referenced_tables)?;
push_predicates(right, where_clause, referenced_tables)?;
@@ -300,34 +311,41 @@ fn push_predicates(
let mut i = 0;
while i < predicates.len() {
// try to push the predicate to the left side first, then to the right side
// temporarily take ownership of the predicate
let predicate_owned = predicates[i].take_ownership();
// left join predicates cant be pushed to the left side
// For a join like SELECT * FROM left INNER JOIN right ON left.id = right.id AND left.name = 'foo'
// the predicate 'left.name = 'foo' can already be evaluated in the outer loop (left side of join)
// because the row can immediately be skipped if left.name != 'foo'.
// But for a LEFT JOIN, we can't do this since we need to ensure that all rows from the left table are included,
// even if there are no matching rows from the right table. This is why we can't push LEFT JOIN predicates to the left side.
let push_result = if *outer {
Some(predicate_owned)
} else {
push_predicate(left, predicate_owned, referenced_tables)?
};
// if the predicate was pushed to a child, remove it from the list
// Try pushing to left side first (see comment above for reasoning)
let Some(predicate) = push_result else {
predicates.remove(i);
continue;
};
// otherwise try to push it to the right side
// if it was pushed to the right side, remove it from the list
// Then try right side
let Some(predicate) = push_predicate(right, predicate, referenced_tables)? else {
predicates.remove(i);
continue;
};
// otherwise keep the predicate in the list
// If neither side could take it, keep in join predicates (not sure if this actually happens in practice)
// this is effectively the same as pushing to the right side, so maybe it could be removed and assert here
// that we don't reach this code
predicates[i] = predicate;
i += 1;
}
Ok(())
}
// Base cases - nowhere else to push to
SourceOperator::Scan { .. } => Ok(()),
SourceOperator::Search { .. } => Ok(()),
SourceOperator::Nothing => Ok(()),
@@ -349,24 +367,29 @@ fn push_predicate(
table_reference,
..
} => {
// Find position of this table in referenced_tables array
let table_index = referenced_tables
.iter()
.position(|t| t.table_identifier == table_reference.table_identifier)
.unwrap();
// Get bitmask showing which tables this predicate references
let predicate_bitmask =
get_table_ref_bitmask_for_ast_expr(referenced_tables, &predicate)?;
// the expression is allowed to refer to tables on its left, i.e. the righter bits in the mask
// e.g. if this table is 0010, and the table on its right in the join is 0100:
// if predicate_bitmask is 0011, the predicate can be pushed (refers to this table and the table on its left)
// if predicate_bitmask is 0001, the predicate can be pushed (refers to the table on its left)
// if predicate_bitmask is 0101, the predicate can't be pushed (refers to this table and a table on its right)
// Each table has a bit position based on join order from left to right
// e.g. in SELECT * FROM t1 JOIN t2 JOIN t3
// t1 is position 0 (001), t2 is position 1 (010), t3 is position 2 (100)
// To push a predicate to a given table, it can only reference that table and tables to its left
// Example: For table t2 at position 1 (bit 010):
// - Can push: 011 (t2 + t1), 001 (just t1), 010 (just t2)
// - Can't push: 110 (t2 + t3)
let next_table_on_the_right_in_join_bitmask = 1 << (table_index + 1);
if predicate_bitmask >= next_table_on_the_right_in_join_bitmask {
return Ok(Some(predicate));
}
// Add predicate to this table's filters
if predicates.is_none() {
predicates.replace(vec![predicate]);
} else {
@@ -375,7 +398,8 @@ fn push_predicate(
Ok(None)
}
SourceOperator::Search { .. } => Ok(Some(predicate)),
// Search nodes don't exist yet at this point; Scans are transformed to Search in use_indexes()
SourceOperator::Search { .. } => unreachable!(),
SourceOperator::Join {
left,
right,
@@ -383,31 +407,36 @@ fn push_predicate(
outer,
..
} => {
// Try pushing to left side first
let push_result_left = push_predicate(left, predicate, referenced_tables)?;
if push_result_left.is_none() {
return Ok(None);
}
// Then try right side
let push_result_right =
push_predicate(right, push_result_left.unwrap(), referenced_tables)?;
if push_result_right.is_none() {
return Ok(None);
}
// For LEFT JOIN, predicates must stay at join level
if *outer {
return Ok(Some(push_result_right.unwrap()));
}
let pred = push_result_right.unwrap();
// Get bitmasks for tables referenced in predicate and both sides of join
let table_refs_bitmask = get_table_ref_bitmask_for_ast_expr(referenced_tables, &pred)?;
let left_bitmask = get_table_ref_bitmask_for_operator(referenced_tables, left)?;
let right_bitmask = get_table_ref_bitmask_for_operator(referenced_tables, right)?;
// If predicate doesn't reference tables from both sides, it can't be a join condition
if table_refs_bitmask & left_bitmask == 0 || table_refs_bitmask & right_bitmask == 0 {
return Ok(Some(pred));
}
// Add as join predicate since it references both sides
if join_on_preds.is_none() {
join_on_preds.replace(vec![pred]);
} else {

View File

@@ -8,7 +8,7 @@ use sqlite3_parser::ast;
use crate::{
function::AggFunc,
schema::{BTreeTable, Index},
schema::{BTreeTable, Column, Index},
Result,
};
@@ -60,6 +60,64 @@ pub enum IterationDirection {
Backwards,
}
impl SourceOperator {
pub fn select_star(&self, out_columns: &mut Vec<ResultSetColumn>) {
for (table_ref, col, idx) in self.select_star_helper() {
out_columns.push(ResultSetColumn {
expr: ast::Expr::Column {
database: None,
table: table_ref.table_index,
column: idx,
is_rowid_alias: col.primary_key,
},
contains_aggregates: false,
});
}
}
/// All this ceremony is required to deduplicate columns when joining with USING
fn select_star_helper(&self) -> Vec<(&BTreeTableReference, &Column, usize)> {
match self {
SourceOperator::Join {
left, right, using, ..
} => {
let mut columns = left.select_star_helper();
// Join columns are filtered out from the right side
// in the case of a USING join.
if let Some(using_cols) = using {
let right_columns = right.select_star_helper();
for (table_ref, col, idx) in right_columns {
if !using_cols
.iter()
.any(|using_col| col.name.eq_ignore_ascii_case(&using_col.0))
{
columns.push((table_ref, col, idx));
}
}
} else {
columns.extend(right.select_star_helper());
}
columns
}
SourceOperator::Scan {
table_reference, ..
}
| SourceOperator::Search {
table_reference, ..
} => table_reference
.table
.columns
.iter()
.enumerate()
.map(|(i, col)| (table_reference, col, i))
.collect(),
SourceOperator::Nothing => Vec::new(),
}
}
}
/**
A SourceOperator is a Node in the query plan that reads data from a table.
*/
@@ -75,6 +133,7 @@ pub enum SourceOperator {
right: Box<SourceOperator>,
predicates: Option<Vec<ast::Expr>>,
outer: bool,
using: Option<ast::DistinctNames>,
},
// Scan operator
// This operator is used to scan a table.
@@ -306,7 +365,7 @@ pub fn get_table_ref_bitmask_for_operator<'a>(
table_refs_mask |= 1
<< tables
.iter()
.position(|t| Rc::ptr_eq(&t.table, &table_reference.table))
.position(|t| &t.table_identifier == &table_reference.table_identifier)
.unwrap();
}
SourceOperator::Search {
@@ -315,7 +374,7 @@ pub fn get_table_ref_bitmask_for_operator<'a>(
table_refs_mask |= 1
<< tables
.iter()
.position(|t| Rc::ptr_eq(&t.table, &table_reference.table))
.position(|t| &t.table_identifier == &table_reference.table_identifier)
.unwrap();
}
SourceOperator::Nothing => {}

View File

@@ -281,19 +281,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
for column in columns.clone() {
match column {
ast::ResultColumn::Star => {
for table_reference in plan.referenced_tables.iter() {
for (idx, col) in table_reference.table.columns.iter().enumerate() {
plan.result_columns.push(ResultSetColumn {
expr: ast::Expr::Column {
database: None, // TODO: support different databases
table: table_reference.table_index,
column: idx,
is_rowid_alias: col.primary_key,
},
contains_aggregates: false,
});
}
}
plan.source.select_star(&mut plan.result_columns);
}
ast::ResultColumn::TableStar(name) => {
let name_normalized = normalize_ident(name.0.as_str());
@@ -538,13 +526,14 @@ fn parse_from(
let mut table_index = 1;
for join in from.joins.unwrap_or_default().into_iter() {
let (right, outer, predicates) =
let (right, outer, using, predicates) =
parse_join(schema, join, operator_id_counter, &mut tables, table_index)?;
operator = SourceOperator::Join {
left: Box::new(operator),
right: Box::new(right),
predicates,
outer,
using,
id: operator_id_counter.get_next_id(),
};
table_index += 1;
@@ -559,7 +548,12 @@ fn parse_join(
operator_id_counter: &mut OperatorIdCounter,
tables: &mut Vec<BTreeTableReference>,
table_index: usize,
) -> Result<(SourceOperator, bool, Option<Vec<ast::Expr>>)> {
) -> Result<(
SourceOperator,
bool,
Option<ast::DistinctNames>,
Option<Vec<ast::Expr>>,
)> {
let ast::JoinedSelectTable {
operator,
table,
@@ -588,18 +582,62 @@ fn parse_join(
tables.push(table.clone());
let outer = match operator {
let (outer, natural) = match operator {
ast::JoinOperator::TypedJoin(Some(join_type)) => {
if join_type == JoinType::LEFT | JoinType::OUTER {
true
} else {
join_type == JoinType::RIGHT | JoinType::OUTER
}
let is_outer = join_type.contains(JoinType::OUTER);
let is_natural = join_type.contains(JoinType::NATURAL);
(is_outer, is_natural)
}
_ => false,
_ => (false, false),
};
let mut using = None;
let mut predicates = None;
if natural && constraint.is_some() {
crate::bail_parse_error!("NATURAL JOIN cannot be combined with ON or USING clause");
}
let constraint = if natural {
// NATURAL JOIN is first transformed into a USING join with the common columns
let left_tables = &tables[..table_index];
assert!(!left_tables.is_empty());
let right_table = &tables[table_index];
let right_cols = &right_table.table.columns;
let mut distinct_names = None;
// TODO: O(n^2) maybe not great for large tables or big multiway joins
for right_col in right_cols.iter() {
let mut found_match = false;
for left_table in left_tables.iter() {
for left_col in left_table.table.columns.iter() {
if left_col.name == right_col.name {
if distinct_names.is_none() {
distinct_names =
Some(ast::DistinctNames::new(ast::Name(left_col.name.clone())));
} else {
distinct_names
.as_mut()
.unwrap()
.insert(ast::Name(left_col.name.clone()))
.unwrap();
}
found_match = true;
break;
}
}
if found_match {
break;
}
}
}
if distinct_names.is_none() {
crate::bail_parse_error!("No columns found to NATURAL join on");
}
Some(ast::JoinConstraint::Using(distinct_names.unwrap()))
} else {
constraint
};
if let Some(constraint) = constraint {
match constraint {
ast::JoinConstraint::On(expr) => {
@@ -610,7 +648,66 @@ fn parse_join(
}
predicates = Some(preds);
}
ast::JoinConstraint::Using(_) => todo!("USING joins not supported yet"),
ast::JoinConstraint::Using(distinct_names) => {
// USING join is replaced with a list of equality predicates
let mut using_predicates = vec![];
for distinct_name in distinct_names.iter() {
let name_normalized = normalize_ident(distinct_name.0.as_str());
let left_tables = &tables[..table_index];
assert!(!left_tables.is_empty());
let right_table = &tables[table_index];
let mut left_col = None;
for (left_table_idx, left_table) in left_tables.iter().enumerate() {
left_col = left_table
.table
.columns
.iter()
.enumerate()
.find(|(_, col)| col.name == name_normalized)
.map(|(idx, col)| (left_table_idx, idx, col));
if left_col.is_some() {
break;
}
}
if left_col.is_none() {
crate::bail_parse_error!(
"cannot join using column {} - column not present in all tables",
distinct_name.0
);
}
let right_col = right_table
.table
.columns
.iter()
.enumerate()
.find(|(_, col)| col.name == name_normalized);
if right_col.is_none() {
crate::bail_parse_error!(
"cannot join using column {} - column not present in all tables",
distinct_name.0
);
}
let (left_table_idx, left_col_idx, left_col) = left_col.unwrap();
let (right_col_idx, right_col) = right_col.unwrap();
using_predicates.push(ast::Expr::Binary(
Box::new(ast::Expr::Column {
database: None,
table: left_table_idx,
column: left_col_idx,
is_rowid_alias: left_col.primary_key,
}),
ast::Operator::Equals,
Box::new(ast::Expr::Column {
database: None,
table: right_table.table_index,
column: right_col_idx,
is_rowid_alias: right_col.primary_key,
}),
));
}
predicates = Some(using_predicates);
using = Some(distinct_names);
}
}
}
@@ -622,6 +719,7 @@ fn parse_join(
iter_dir: None,
},
outer,
using,
predicates,
))
}

View File

@@ -212,4 +212,38 @@ do_execsql_test join-utilizing-both-seekrowid-and-secondary-index {
select u.first_name, p.name from users u join products p on u.id = p.id and u.age > 70;
} {Matthew|boots
Nicholas|shorts
Jamie|hat}
Jamie|hat}
# important difference between regular SELECT * join and a SELECT * USING join is that the join keys are deduplicated
# from the result in the USING case.
do_execsql_test join-using {
select * from users join products using (id) limit 3;
} {"1|Jamie|Foster|dylan00@example.com|496-522-9493|62375 Johnson Rest Suite 322|West Lauriestad|IL|35865|94|hat|79.0
2|Cindy|Salazar|williamsrebecca@example.com|287-934-1135|75615 Stacey Shore|South Stephanie|NC|85181|37|cap|82.0
3|Tommy|Perry|warechristopher@example.org|001-288-554-8139x0276|2896 Paul Fall Apt. 972|Michaelborough|VA|15691|18|shirt|18.0"}
do_execsql_test join-using-multiple {
select u.first_name, u.last_name, p.name from users u join users u2 using(id) join products p using(id) limit 3;
} {"Jamie|Foster|hat
Cindy|Salazar|cap
Tommy|Perry|shirt"}
# NATURAL JOIN desugars to JOIN USING (common_column1, common_column2...)
do_execsql_test join-using {
select * from users natural join products limit 3;
} {"1|Jamie|Foster|dylan00@example.com|496-522-9493|62375 Johnson Rest Suite 322|West Lauriestad|IL|35865|94|hat|79.0
2|Cindy|Salazar|williamsrebecca@example.com|287-934-1135|75615 Stacey Shore|South Stephanie|NC|85181|37|cap|82.0
3|Tommy|Perry|warechristopher@example.org|001-288-554-8139x0276|2896 Paul Fall Apt. 972|Michaelborough|VA|15691|18|shirt|18.0"}
do_execsql_test natural-join-multiple {
select u.first_name, u2.last_name, p.name from users u natural join users u2 natural join products p limit 3;
} {"Jamie|Foster|hat
Cindy|Salazar|cap
Tommy|Perry|shirt"}
# have to be able to join between 1st table and 3rd table as well
do_execsql_test natural-join-and-using-join {
select u.id, u2.id, p.id from users u natural join products p join users u2 using (first_name) limit 3;
} {"1|1|1
1|1204|1
1|1261|1"}