mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-03 23:34:24 +01:00
Support WalkControl in walk_expr_mut
Now walk_expr_mut can use WalkControl to skip parts of the expression tree. This makes it consistent with walk_expr.
This commit is contained in:
@@ -3237,167 +3237,170 @@ where
|
||||
}
|
||||
|
||||
/// Recursively walks a mutable expression, applying a function to each sub-expression.
|
||||
pub fn walk_expr_mut<F>(expr: &mut ast::Expr, func: &mut F) -> Result<()>
|
||||
pub fn walk_expr_mut<F>(expr: &mut ast::Expr, func: &mut F) -> Result<WalkControl>
|
||||
where
|
||||
F: FnMut(&mut ast::Expr) -> Result<()>,
|
||||
F: FnMut(&mut ast::Expr) -> Result<WalkControl>,
|
||||
{
|
||||
func(expr)?;
|
||||
match expr {
|
||||
ast::Expr::Between {
|
||||
lhs, start, end, ..
|
||||
} => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
walk_expr_mut(start, func)?;
|
||||
walk_expr_mut(end, func)?;
|
||||
}
|
||||
ast::Expr::Binary(lhs, _, rhs) => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
walk_expr_mut(rhs, func)?;
|
||||
}
|
||||
ast::Expr::Case {
|
||||
base,
|
||||
when_then_pairs,
|
||||
else_expr,
|
||||
} => {
|
||||
if let Some(base_expr) = base {
|
||||
walk_expr_mut(base_expr, func)?;
|
||||
}
|
||||
for (when_expr, then_expr) in when_then_pairs {
|
||||
walk_expr_mut(when_expr, func)?;
|
||||
walk_expr_mut(then_expr, func)?;
|
||||
}
|
||||
if let Some(else_expr) = else_expr {
|
||||
walk_expr_mut(else_expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::Cast { expr, .. } => {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
ast::Expr::Collate(expr, _) => {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
ast::Expr::Exists(_) | ast::Expr::Subquery(_) => {
|
||||
// TODO: Walk through select statements if needed
|
||||
}
|
||||
ast::Expr::FunctionCall {
|
||||
args,
|
||||
order_by,
|
||||
filter_over,
|
||||
..
|
||||
} => {
|
||||
for arg in args {
|
||||
walk_expr_mut(arg, func)?;
|
||||
}
|
||||
for sort_col in order_by {
|
||||
walk_expr_mut(&mut sort_col.expr, func)?;
|
||||
}
|
||||
if let Some(filter_clause) = &mut filter_over.filter_clause {
|
||||
walk_expr_mut(filter_clause, func)?;
|
||||
}
|
||||
if let Some(over_clause) = &mut filter_over.over_clause {
|
||||
match over_clause {
|
||||
ast::Over::Window(window) => {
|
||||
for part_expr in &mut window.partition_by {
|
||||
walk_expr_mut(part_expr, func)?;
|
||||
}
|
||||
for sort_col in &mut window.order_by {
|
||||
walk_expr_mut(&mut sort_col.expr, func)?;
|
||||
}
|
||||
if let Some(frame_clause) = &mut window.frame_clause {
|
||||
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
|
||||
if let Some(end_bound) = &mut frame_clause.end {
|
||||
walk_expr_mut_frame_bound(end_bound, func)?;
|
||||
match func(expr)? {
|
||||
WalkControl::Continue => {
|
||||
match expr {
|
||||
ast::Expr::Between {
|
||||
lhs, start, end, ..
|
||||
} => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
walk_expr_mut(start, func)?;
|
||||
walk_expr_mut(end, func)?;
|
||||
}
|
||||
ast::Expr::Binary(lhs, _, rhs) => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
walk_expr_mut(rhs, func)?;
|
||||
}
|
||||
ast::Expr::Case {
|
||||
base,
|
||||
when_then_pairs,
|
||||
else_expr,
|
||||
} => {
|
||||
if let Some(base_expr) = base {
|
||||
walk_expr_mut(base_expr, func)?;
|
||||
}
|
||||
for (when_expr, then_expr) in when_then_pairs {
|
||||
walk_expr_mut(when_expr, func)?;
|
||||
walk_expr_mut(then_expr, func)?;
|
||||
}
|
||||
if let Some(else_expr) = else_expr {
|
||||
walk_expr_mut(else_expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::Cast { expr, .. } => {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
ast::Expr::Collate(expr, _) => {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
ast::Expr::Exists(_) | ast::Expr::Subquery(_) => {
|
||||
// TODO: Walk through select statements if needed
|
||||
}
|
||||
ast::Expr::FunctionCall {
|
||||
args,
|
||||
order_by,
|
||||
filter_over,
|
||||
..
|
||||
} => {
|
||||
for arg in args {
|
||||
walk_expr_mut(arg, func)?;
|
||||
}
|
||||
for sort_col in order_by {
|
||||
walk_expr_mut(&mut sort_col.expr, func)?;
|
||||
}
|
||||
if let Some(filter_clause) = &mut filter_over.filter_clause {
|
||||
walk_expr_mut(filter_clause, func)?;
|
||||
}
|
||||
if let Some(over_clause) = &mut filter_over.over_clause {
|
||||
match over_clause {
|
||||
ast::Over::Window(window) => {
|
||||
for part_expr in &mut window.partition_by {
|
||||
walk_expr_mut(part_expr, func)?;
|
||||
}
|
||||
for sort_col in &mut window.order_by {
|
||||
walk_expr_mut(&mut sort_col.expr, func)?;
|
||||
}
|
||||
if let Some(frame_clause) = &mut window.frame_clause {
|
||||
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
|
||||
if let Some(end_bound) = &mut frame_clause.end {
|
||||
walk_expr_mut_frame_bound(end_bound, func)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::Over::Name(_) => {}
|
||||
}
|
||||
}
|
||||
ast::Over::Name(_) => {}
|
||||
}
|
||||
ast::Expr::FunctionCallStar { filter_over, .. } => {
|
||||
if let Some(ref mut filter_clause) = filter_over.filter_clause {
|
||||
walk_expr_mut(filter_clause, func)?;
|
||||
}
|
||||
if let Some(ref mut over_clause) = filter_over.over_clause {
|
||||
match over_clause {
|
||||
ast::Over::Window(window) => {
|
||||
for part_expr in &mut window.partition_by {
|
||||
walk_expr_mut(part_expr, func)?;
|
||||
}
|
||||
for sort_col in &mut window.order_by {
|
||||
walk_expr_mut(&mut sort_col.expr, func)?;
|
||||
}
|
||||
if let Some(frame_clause) = &mut window.frame_clause {
|
||||
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
|
||||
if let Some(end_bound) = &mut frame_clause.end {
|
||||
walk_expr_mut_frame_bound(end_bound, func)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::Over::Name(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::Expr::InList { lhs, rhs, .. } => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
for expr in rhs {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::InSelect { lhs, rhs: _, .. } => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
// TODO: Walk through select statements if needed
|
||||
}
|
||||
ast::Expr::InTable { lhs, args, .. } => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
for expr in args {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
ast::Expr::Like {
|
||||
lhs, rhs, escape, ..
|
||||
} => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
walk_expr_mut(rhs, func)?;
|
||||
if let Some(esc_expr) = escape {
|
||||
walk_expr_mut(esc_expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::Parenthesized(exprs) => {
|
||||
for expr in exprs {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::Raise(_, expr) => {
|
||||
if let Some(raise_expr) = expr {
|
||||
walk_expr_mut(raise_expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::Unary(_, expr) => {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
ast::Expr::Id(_)
|
||||
| ast::Expr::Column { .. }
|
||||
| ast::Expr::RowId { .. }
|
||||
| ast::Expr::Literal(_)
|
||||
| ast::Expr::DoublyQualified(..)
|
||||
| ast::Expr::Name(_)
|
||||
| ast::Expr::Qualified(..)
|
||||
| ast::Expr::Variable(_)
|
||||
| ast::Expr::Register(_) => {
|
||||
// No nested expressions
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::Expr::FunctionCallStar { filter_over, .. } => {
|
||||
if let Some(ref mut filter_clause) = filter_over.filter_clause {
|
||||
walk_expr_mut(filter_clause, func)?;
|
||||
}
|
||||
if let Some(ref mut over_clause) = filter_over.over_clause {
|
||||
match over_clause {
|
||||
ast::Over::Window(window) => {
|
||||
for part_expr in &mut window.partition_by {
|
||||
walk_expr_mut(part_expr, func)?;
|
||||
}
|
||||
for sort_col in &mut window.order_by {
|
||||
walk_expr_mut(&mut sort_col.expr, func)?;
|
||||
}
|
||||
if let Some(frame_clause) = &mut window.frame_clause {
|
||||
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
|
||||
if let Some(end_bound) = &mut frame_clause.end {
|
||||
walk_expr_mut_frame_bound(end_bound, func)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::Over::Name(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::Expr::InList { lhs, rhs, .. } => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
for expr in rhs {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::InSelect { lhs, rhs: _, .. } => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
// TODO: Walk through select statements if needed
|
||||
}
|
||||
ast::Expr::InTable { lhs, args, .. } => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
for expr in args {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
ast::Expr::Like {
|
||||
lhs, rhs, escape, ..
|
||||
} => {
|
||||
walk_expr_mut(lhs, func)?;
|
||||
walk_expr_mut(rhs, func)?;
|
||||
if let Some(esc_expr) = escape {
|
||||
walk_expr_mut(esc_expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::Parenthesized(exprs) => {
|
||||
for expr in exprs {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::Raise(_, expr) => {
|
||||
if let Some(raise_expr) = expr {
|
||||
walk_expr_mut(raise_expr, func)?;
|
||||
}
|
||||
}
|
||||
ast::Expr::Unary(_, expr) => {
|
||||
walk_expr_mut(expr, func)?;
|
||||
}
|
||||
ast::Expr::Id(_)
|
||||
| ast::Expr::Column { .. }
|
||||
| ast::Expr::RowId { .. }
|
||||
| ast::Expr::Literal(_)
|
||||
| ast::Expr::DoublyQualified(..)
|
||||
| ast::Expr::Name(_)
|
||||
| ast::Expr::Qualified(..)
|
||||
| ast::Expr::Variable(_)
|
||||
| ast::Expr::Register(_) => {
|
||||
// No nested expressions
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
WalkControl::SkipChildren => return Ok(WalkControl::Continue),
|
||||
};
|
||||
Ok(WalkControl::Continue)
|
||||
}
|
||||
|
||||
fn walk_expr_mut_frame_bound<F>(bound: &mut ast::FrameBound, func: &mut F) -> Result<()>
|
||||
fn walk_expr_mut_frame_bound<F>(bound: &mut ast::FrameBound, func: &mut F) -> Result<WalkControl>
|
||||
where
|
||||
F: FnMut(&mut ast::Expr) -> Result<()>,
|
||||
F: FnMut(&mut ast::Expr) -> Result<WalkControl>,
|
||||
{
|
||||
match bound {
|
||||
ast::FrameBound::Following(expr) | ast::FrameBound::Preceding(expr) => {
|
||||
@@ -3408,7 +3411,7 @@ where
|
||||
| ast::FrameBound::UnboundedPreceding => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(WalkControl::Continue)
|
||||
}
|
||||
|
||||
pub fn get_expr_affinity(
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::{
|
||||
parameters::PARAM_PREFIX,
|
||||
schema::{Index, IndexColumn, Schema, Table},
|
||||
translate::{
|
||||
expr::walk_expr_mut, optimizer::access_method::AccessMethodParams,
|
||||
expr::walk_expr_mut, expr::WalkControl, optimizer::access_method::AccessMethodParams,
|
||||
optimizer::constraints::TableConstraints, plan::Scan, plan::TerminationKey,
|
||||
},
|
||||
types::SeekOp,
|
||||
@@ -1443,71 +1443,74 @@ fn build_seek_def(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> {
|
||||
walk_expr_mut(top_level_expr, &mut |expr: &mut ast::Expr| -> Result<()> {
|
||||
match expr {
|
||||
ast::Expr::Id(id) => {
|
||||
// Convert "true" and "false" to 1 and 0
|
||||
let id_bytes = id.as_str().as_bytes();
|
||||
match_ignore_ascii_case!(match id_bytes {
|
||||
b"true" => {
|
||||
*expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_owned()));
|
||||
}
|
||||
b"false" => {
|
||||
*expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_owned()));
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
}
|
||||
ast::Expr::Variable(var) => {
|
||||
if var.is_empty() {
|
||||
// rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and
|
||||
// all the expressions are rewritten in the order they come in the statement
|
||||
*expr = ast::Expr::Variable(format!("{PARAM_PREFIX}{param_idx}"));
|
||||
*param_idx += 1;
|
||||
pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result<WalkControl> {
|
||||
walk_expr_mut(
|
||||
top_level_expr,
|
||||
&mut |expr: &mut ast::Expr| -> Result<WalkControl> {
|
||||
match expr {
|
||||
ast::Expr::Id(id) => {
|
||||
// Convert "true" and "false" to 1 and 0
|
||||
let id_bytes = id.as_str().as_bytes();
|
||||
match_ignore_ascii_case!(match id_bytes {
|
||||
b"true" => {
|
||||
*expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_owned()));
|
||||
}
|
||||
b"false" => {
|
||||
*expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_owned()));
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
}
|
||||
}
|
||||
ast::Expr::Between {
|
||||
lhs,
|
||||
not,
|
||||
start,
|
||||
end,
|
||||
} => {
|
||||
// Convert `y NOT BETWEEN x AND z` to `x > y OR y > z`
|
||||
let (lower_op, upper_op) = if *not {
|
||||
(ast::Operator::Greater, ast::Operator::Greater)
|
||||
} else {
|
||||
// Convert `y BETWEEN x AND z` to `x <= y AND y <= z`
|
||||
(ast::Operator::LessEquals, ast::Operator::LessEquals)
|
||||
};
|
||||
|
||||
let start = start.take_ownership();
|
||||
let lhs = lhs.take_ownership();
|
||||
let end = end.take_ownership();
|
||||
|
||||
let lower_bound =
|
||||
ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone()));
|
||||
let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end));
|
||||
|
||||
if *not {
|
||||
*expr = ast::Expr::Binary(
|
||||
Box::new(lower_bound),
|
||||
ast::Operator::Or,
|
||||
Box::new(upper_bound),
|
||||
);
|
||||
} else {
|
||||
*expr = ast::Expr::Binary(
|
||||
Box::new(lower_bound),
|
||||
ast::Operator::And,
|
||||
Box::new(upper_bound),
|
||||
);
|
||||
ast::Expr::Variable(var) => {
|
||||
if var.is_empty() {
|
||||
// rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and
|
||||
// all the expressions are rewritten in the order they come in the statement
|
||||
*expr = ast::Expr::Variable(format!("{PARAM_PREFIX}{param_idx}"));
|
||||
*param_idx += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
ast::Expr::Between {
|
||||
lhs,
|
||||
not,
|
||||
start,
|
||||
end,
|
||||
} => {
|
||||
// Convert `y NOT BETWEEN x AND z` to `x > y OR y > z`
|
||||
let (lower_op, upper_op) = if *not {
|
||||
(ast::Operator::Greater, ast::Operator::Greater)
|
||||
} else {
|
||||
// Convert `y BETWEEN x AND z` to `x <= y AND y <= z`
|
||||
(ast::Operator::LessEquals, ast::Operator::LessEquals)
|
||||
};
|
||||
|
||||
Ok(())
|
||||
})
|
||||
let start = start.take_ownership();
|
||||
let lhs = lhs.take_ownership();
|
||||
let end = end.take_ownership();
|
||||
|
||||
let lower_bound =
|
||||
ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone()));
|
||||
let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end));
|
||||
|
||||
if *not {
|
||||
*expr = ast::Expr::Binary(
|
||||
Box::new(lower_bound),
|
||||
ast::Operator::Or,
|
||||
Box::new(upper_bound),
|
||||
);
|
||||
} else {
|
||||
*expr = ast::Expr::Binary(
|
||||
Box::new(lower_bound),
|
||||
ast::Operator::And,
|
||||
Box::new(upper_bound),
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(WalkControl::Continue)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
trait TakeOwnership {
|
||||
|
||||
@@ -152,65 +152,39 @@ pub fn bind_column_references(
|
||||
referenced_tables: &mut TableReferences,
|
||||
result_columns: Option<&[ResultSetColumn]>,
|
||||
connection: &Arc<crate::Connection>,
|
||||
) -> Result<()> {
|
||||
walk_expr_mut(top_level_expr, &mut |expr: &mut Expr| -> Result<()> {
|
||||
match expr {
|
||||
Expr::Id(id) => {
|
||||
// true and false are special constants that are effectively aliases for 1 and 0
|
||||
// and not identifiers of columns
|
||||
let id_bytes = id.as_str().as_bytes();
|
||||
match_ignore_ascii_case!(match id_bytes {
|
||||
b"true" | b"false" => {
|
||||
return Ok(());
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
let normalized_id = normalize_ident(id.as_str());
|
||||
|
||||
if !referenced_tables.joined_tables().is_empty() {
|
||||
if let Some(row_id_expr) = parse_row_id(
|
||||
&normalized_id,
|
||||
referenced_tables.joined_tables()[0].internal_id,
|
||||
|| referenced_tables.joined_tables().len() != 1,
|
||||
)? {
|
||||
*expr = row_id_expr;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
let mut match_result = None;
|
||||
|
||||
// First check joined tables
|
||||
for joined_table in referenced_tables.joined_tables().iter() {
|
||||
let col_idx = joined_table.table.columns().iter().position(|c| {
|
||||
c.name
|
||||
.as_ref()
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
|
||||
});
|
||||
if col_idx.is_some() {
|
||||
if match_result.is_some() {
|
||||
crate::bail_parse_error!("Column {} is ambiguous", id.as_str());
|
||||
) -> Result<WalkControl> {
|
||||
walk_expr_mut(
|
||||
top_level_expr,
|
||||
&mut |expr: &mut Expr| -> Result<WalkControl> {
|
||||
match expr {
|
||||
Expr::Id(id) => {
|
||||
// true and false are special constants that are effectively aliases for 1 and 0
|
||||
// and not identifiers of columns
|
||||
let id_bytes = id.as_str().as_bytes();
|
||||
match_ignore_ascii_case!(match id_bytes {
|
||||
b"true" | b"false" => {
|
||||
return Ok(WalkControl::Continue);
|
||||
}
|
||||
let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap();
|
||||
match_result = Some((
|
||||
joined_table.internal_id,
|
||||
col_idx.unwrap(),
|
||||
col.is_rowid_alias,
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
let normalized_id = normalize_ident(id.as_str());
|
||||
|
||||
// Then check outer query references, if we still didn't find something.
|
||||
// Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous)
|
||||
// but in the case of subqueries, the inner query takes precedence.
|
||||
// For example:
|
||||
// SELECT * FROM t WHERE x = (SELECT x FROM t2)
|
||||
// In this case, there is no ambiguity:
|
||||
// - x in the outer query refers to t.x,
|
||||
// - x in the inner query refers to t2.x.
|
||||
if match_result.is_none() {
|
||||
for outer_ref in referenced_tables.outer_query_refs().iter() {
|
||||
let col_idx = outer_ref.table.columns().iter().position(|c| {
|
||||
if !referenced_tables.joined_tables().is_empty() {
|
||||
if let Some(row_id_expr) = parse_row_id(
|
||||
&normalized_id,
|
||||
referenced_tables.joined_tables()[0].internal_id,
|
||||
|| referenced_tables.joined_tables().len() != 1,
|
||||
)? {
|
||||
*expr = row_id_expr;
|
||||
|
||||
return Ok(WalkControl::Continue);
|
||||
}
|
||||
}
|
||||
let mut match_result = None;
|
||||
|
||||
// First check joined tables
|
||||
for joined_table in referenced_tables.joined_tables().iter() {
|
||||
let col_idx = joined_table.table.columns().iter().position(|c| {
|
||||
c.name
|
||||
.as_ref()
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
|
||||
@@ -219,151 +193,183 @@ pub fn bind_column_references(
|
||||
if match_result.is_some() {
|
||||
crate::bail_parse_error!("Column {} is ambiguous", id.as_str());
|
||||
}
|
||||
let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap();
|
||||
match_result =
|
||||
Some((outer_ref.internal_id, col_idx.unwrap(), col.is_rowid_alias));
|
||||
let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap();
|
||||
match_result = Some((
|
||||
joined_table.internal_id,
|
||||
col_idx.unwrap(),
|
||||
col.is_rowid_alias,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((table_id, col_idx, is_rowid_alias)) = match_result {
|
||||
*expr = Expr::Column {
|
||||
database: None, // TODO: support different databases
|
||||
table: table_id,
|
||||
column: col_idx,
|
||||
is_rowid_alias,
|
||||
};
|
||||
referenced_tables.mark_column_used(table_id, col_idx);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(result_columns) = result_columns {
|
||||
for result_column in result_columns.iter() {
|
||||
if result_column
|
||||
.name(referenced_tables)
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
|
||||
{
|
||||
*expr = result_column.expr.clone();
|
||||
return Ok(());
|
||||
// Then check outer query references, if we still didn't find something.
|
||||
// Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous)
|
||||
// but in the case of subqueries, the inner query takes precedence.
|
||||
// For example:
|
||||
// SELECT * FROM t WHERE x = (SELECT x FROM t2)
|
||||
// In this case, there is no ambiguity:
|
||||
// - x in the outer query refers to t.x,
|
||||
// - x in the inner query refers to t2.x.
|
||||
if match_result.is_none() {
|
||||
for outer_ref in referenced_tables.outer_query_refs().iter() {
|
||||
let col_idx = outer_ref.table.columns().iter().position(|c| {
|
||||
c.name
|
||||
.as_ref()
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
|
||||
});
|
||||
if col_idx.is_some() {
|
||||
if match_result.is_some() {
|
||||
crate::bail_parse_error!("Column {} is ambiguous", id.as_str());
|
||||
}
|
||||
let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap();
|
||||
match_result = Some((
|
||||
outer_ref.internal_id,
|
||||
col_idx.unwrap(),
|
||||
col.is_rowid_alias,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((table_id, col_idx, is_rowid_alias)) = match_result {
|
||||
*expr = Expr::Column {
|
||||
database: None, // TODO: support different databases
|
||||
table: table_id,
|
||||
column: col_idx,
|
||||
is_rowid_alias,
|
||||
};
|
||||
referenced_tables.mark_column_used(table_id, col_idx);
|
||||
return Ok(WalkControl::Continue);
|
||||
}
|
||||
|
||||
if let Some(result_columns) = result_columns {
|
||||
for result_column in result_columns.iter() {
|
||||
if result_column
|
||||
.name(referenced_tables)
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
|
||||
{
|
||||
*expr = result_column.expr.clone();
|
||||
return Ok(WalkControl::Continue);
|
||||
}
|
||||
}
|
||||
}
|
||||
// SQLite behavior: Only double-quoted identifiers get fallback to string literals
|
||||
// Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns
|
||||
if id.is_double_quoted() {
|
||||
// Convert failed double-quoted identifier to string literal
|
||||
*expr = Expr::Literal(Literal::String(id.as_str().to_string()));
|
||||
Ok(WalkControl::Continue)
|
||||
} else {
|
||||
// Unquoted identifiers must resolve to columns - no fallback
|
||||
crate::bail_parse_error!("no such column: {}", id.as_str())
|
||||
}
|
||||
}
|
||||
// SQLite behavior: Only double-quoted identifiers get fallback to string literals
|
||||
// Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns
|
||||
if id.is_double_quoted() {
|
||||
// Convert failed double-quoted identifier to string literal
|
||||
*expr = Expr::Literal(Literal::String(id.as_str().to_string()));
|
||||
Ok(())
|
||||
} else {
|
||||
// Unquoted identifiers must resolve to columns - no fallback
|
||||
crate::bail_parse_error!("no such column: {}", id.as_str())
|
||||
}
|
||||
}
|
||||
Expr::Qualified(tbl, id) => {
|
||||
let normalized_table_name = normalize_ident(tbl.as_str());
|
||||
let matching_tbl = referenced_tables
|
||||
.find_table_and_internal_id_by_identifier(&normalized_table_name);
|
||||
if matching_tbl.is_none() {
|
||||
crate::bail_parse_error!("no such table: {}", normalized_table_name);
|
||||
}
|
||||
let (tbl_id, tbl) = matching_tbl.unwrap();
|
||||
let normalized_id = normalize_ident(id.as_str());
|
||||
Expr::Qualified(tbl, id) => {
|
||||
let normalized_table_name = normalize_ident(tbl.as_str());
|
||||
let matching_tbl = referenced_tables
|
||||
.find_table_and_internal_id_by_identifier(&normalized_table_name);
|
||||
if matching_tbl.is_none() {
|
||||
crate::bail_parse_error!("no such table: {}", normalized_table_name);
|
||||
}
|
||||
let (tbl_id, tbl) = matching_tbl.unwrap();
|
||||
let normalized_id = normalize_ident(id.as_str());
|
||||
|
||||
if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? {
|
||||
*expr = row_id_expr;
|
||||
if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? {
|
||||
*expr = row_id_expr;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
let col_idx = tbl.columns().iter().position(|c| {
|
||||
c.name
|
||||
.as_ref()
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
|
||||
});
|
||||
let Some(col_idx) = col_idx else {
|
||||
crate::bail_parse_error!("no such column: {}", normalized_id);
|
||||
};
|
||||
let col = tbl.columns().get(col_idx).unwrap();
|
||||
*expr = Expr::Column {
|
||||
database: None, // TODO: support different databases
|
||||
table: tbl_id,
|
||||
column: col_idx,
|
||||
is_rowid_alias: col.is_rowid_alias,
|
||||
};
|
||||
referenced_tables.mark_column_used(tbl_id, col_idx);
|
||||
Ok(())
|
||||
}
|
||||
Expr::DoublyQualified(db_name, tbl_name, col_name) => {
|
||||
let normalized_col_name = normalize_ident(col_name.as_str());
|
||||
|
||||
// Create a QualifiedName and use existing resolve_database_id method
|
||||
let qualified_name = ast::QualifiedName {
|
||||
db_name: Some(db_name.clone()),
|
||||
name: tbl_name.clone(),
|
||||
alias: None,
|
||||
};
|
||||
let database_id = connection.resolve_database_id(&qualified_name)?;
|
||||
|
||||
// Get the table from the specified database
|
||||
let table = connection
|
||||
.with_schema(database_id, |schema| schema.get_table(tbl_name.as_str()))
|
||||
.ok_or_else(|| {
|
||||
crate::LimboError::ParseError(format!(
|
||||
"no such table: {}.{}",
|
||||
db_name.as_str(),
|
||||
tbl_name.as_str()
|
||||
))
|
||||
})?;
|
||||
|
||||
// Find the column in the table
|
||||
let col_idx = table
|
||||
.columns()
|
||||
.iter()
|
||||
.position(|c| {
|
||||
return Ok(WalkControl::Continue);
|
||||
}
|
||||
let col_idx = tbl.columns().iter().position(|c| {
|
||||
c.name
|
||||
.as_ref()
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
crate::LimboError::ParseError(format!(
|
||||
"Column: {}.{}.{} not found",
|
||||
db_name.as_str(),
|
||||
tbl_name.as_str(),
|
||||
col_name.as_str()
|
||||
))
|
||||
})?;
|
||||
|
||||
let col = table.columns().get(col_idx).unwrap();
|
||||
|
||||
// Check if this is a rowid alias
|
||||
let is_rowid_alias = col.is_rowid_alias;
|
||||
|
||||
// Convert to Column expression - since this is a cross-database reference,
|
||||
// we need to create a synthetic table reference for it
|
||||
// For now, we'll error if the table isn't already in the referenced tables
|
||||
let normalized_tbl_name = normalize_ident(tbl_name.as_str());
|
||||
let matching_tbl = referenced_tables
|
||||
.find_table_and_internal_id_by_identifier(&normalized_tbl_name);
|
||||
|
||||
if let Some((tbl_id, _)) = matching_tbl {
|
||||
// Table is already in referenced tables, use existing internal ID
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
|
||||
});
|
||||
let Some(col_idx) = col_idx else {
|
||||
crate::bail_parse_error!("no such column: {}", normalized_id);
|
||||
};
|
||||
let col = tbl.columns().get(col_idx).unwrap();
|
||||
*expr = Expr::Column {
|
||||
database: Some(database_id),
|
||||
database: None, // TODO: support different databases
|
||||
table: tbl_id,
|
||||
column: col_idx,
|
||||
is_rowid_alias,
|
||||
is_rowid_alias: col.is_rowid_alias,
|
||||
};
|
||||
referenced_tables.mark_column_used(tbl_id, col_idx);
|
||||
} else {
|
||||
return Err(crate::LimboError::ParseError(format!(
|
||||
"table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined"
|
||||
)));
|
||||
Ok(WalkControl::Continue)
|
||||
}
|
||||
Expr::DoublyQualified(db_name, tbl_name, col_name) => {
|
||||
let normalized_col_name = normalize_ident(col_name.as_str());
|
||||
|
||||
Ok(())
|
||||
// Create a QualifiedName and use existing resolve_database_id method
|
||||
let qualified_name = ast::QualifiedName {
|
||||
db_name: Some(db_name.clone()),
|
||||
name: tbl_name.clone(),
|
||||
alias: None,
|
||||
};
|
||||
let database_id = connection.resolve_database_id(&qualified_name)?;
|
||||
|
||||
// Get the table from the specified database
|
||||
let table = connection
|
||||
.with_schema(database_id, |schema| schema.get_table(tbl_name.as_str()))
|
||||
.ok_or_else(|| {
|
||||
crate::LimboError::ParseError(format!(
|
||||
"no such table: {}.{}",
|
||||
db_name.as_str(),
|
||||
tbl_name.as_str()
|
||||
))
|
||||
})?;
|
||||
|
||||
// Find the column in the table
|
||||
let col_idx = table
|
||||
.columns()
|
||||
.iter()
|
||||
.position(|c| {
|
||||
c.name
|
||||
.as_ref()
|
||||
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
crate::LimboError::ParseError(format!(
|
||||
"Column: {}.{}.{} not found",
|
||||
db_name.as_str(),
|
||||
tbl_name.as_str(),
|
||||
col_name.as_str()
|
||||
))
|
||||
})?;
|
||||
|
||||
let col = table.columns().get(col_idx).unwrap();
|
||||
|
||||
// Check if this is a rowid alias
|
||||
let is_rowid_alias = col.is_rowid_alias;
|
||||
|
||||
// Convert to Column expression - since this is a cross-database reference,
|
||||
// we need to create a synthetic table reference for it
|
||||
// For now, we'll error if the table isn't already in the referenced tables
|
||||
let normalized_tbl_name = normalize_ident(tbl_name.as_str());
|
||||
let matching_tbl = referenced_tables
|
||||
.find_table_and_internal_id_by_identifier(&normalized_tbl_name);
|
||||
|
||||
if let Some((tbl_id, _)) = matching_tbl {
|
||||
// Table is already in referenced tables, use existing internal ID
|
||||
*expr = Expr::Column {
|
||||
database: Some(database_id),
|
||||
table: tbl_id,
|
||||
column: col_idx,
|
||||
is_rowid_alias,
|
||||
};
|
||||
referenced_tables.mark_column_used(tbl_id, col_idx);
|
||||
} else {
|
||||
return Err(crate::LimboError::ParseError(format!(
|
||||
"table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(WalkControl::Continue)
|
||||
}
|
||||
_ => Ok(WalkControl::Continue),
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use turso_parser::ast::{self, Upsert};
|
||||
|
||||
use crate::translate::expr::WalkControl;
|
||||
use crate::{
|
||||
bail_parse_error,
|
||||
error::SQLITE_CONSTRAINT_NOTNULL,
|
||||
@@ -524,7 +525,7 @@ fn rewrite_upsert_expr_in_place(
|
||||
current_start: usize,
|
||||
conflict_rowid_reg: usize,
|
||||
insertion: &Insertion,
|
||||
) -> crate::Result<()> {
|
||||
) -> crate::Result<WalkControl> {
|
||||
use ast::Expr;
|
||||
|
||||
// helper: return the CURRENT-row register for a column (including rowid alias)
|
||||
@@ -535,34 +536,37 @@ fn rewrite_upsert_expr_in_place(
|
||||
let (idx, _c) = table.get_column_by_name(&normalize_ident(name))?;
|
||||
Some(current_start + idx)
|
||||
};
|
||||
walk_expr_mut(e, &mut |expr: &mut ast::Expr| -> crate::Result<()> {
|
||||
match expr {
|
||||
// EXCLUDED.x -> insertion register
|
||||
Expr::Qualified(ns, ast::Name::Ident(c)) => {
|
||||
if ns.as_str().eq_ignore_ascii_case("excluded") {
|
||||
let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else {
|
||||
bail_parse_error!("no such column in EXCLUDED: {}", c);
|
||||
};
|
||||
*expr = Expr::Register(reg.register);
|
||||
} else if ns.as_str().eq_ignore_ascii_case(table_name)
|
||||
// t.x -> CURRENT, only if t matches the target table name (never "excluded")
|
||||
{
|
||||
if let Some(reg) = col_reg(c.as_str()) {
|
||||
walk_expr_mut(
|
||||
e,
|
||||
&mut |expr: &mut ast::Expr| -> crate::Result<WalkControl> {
|
||||
match expr {
|
||||
// EXCLUDED.x -> insertion register
|
||||
Expr::Qualified(ns, ast::Name::Ident(c)) => {
|
||||
if ns.as_str().eq_ignore_ascii_case("excluded") {
|
||||
let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else {
|
||||
bail_parse_error!("no such column in EXCLUDED: {}", c);
|
||||
};
|
||||
*expr = Expr::Register(reg.register);
|
||||
} else if ns.as_str().eq_ignore_ascii_case(table_name)
|
||||
// t.x -> CURRENT, only if t matches the target table name (never "excluded")
|
||||
{
|
||||
if let Some(reg) = col_reg(c.as_str()) {
|
||||
*expr = Expr::Register(reg);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Unqualified column id -> CURRENT
|
||||
Expr::Id(ast::Name::Ident(name)) => {
|
||||
if let Some(reg) = col_reg(name.as_str()) {
|
||||
*expr = Expr::Register(reg);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Unqualified column id -> CURRENT
|
||||
Expr::Id(ast::Name::Ident(name)) => {
|
||||
if let Some(reg) = col_reg(name.as_str()) {
|
||||
*expr = Expr::Register(reg);
|
||||
Expr::RowId { .. } => {
|
||||
*expr = Expr::Register(conflict_rowid_reg);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Expr::RowId { .. } => {
|
||||
*expr = Expr::Register(conflict_rowid_reg);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
Ok(WalkControl::Continue)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user