Support WalkControl in walk_expr_mut

Now walk_expr_mut can use WalkControl to skip parts of the expression
tree. This makes it consistent with walk_expr.
This commit is contained in:
Piotr Rzysko
2025-09-02 08:56:02 +02:00
parent b911e80607
commit 6224cdbbd3
4 changed files with 440 additions and 424 deletions

View File

@@ -3237,167 +3237,170 @@ where
}
/// Recursively walks a mutable expression, applying a function to each sub-expression.
pub fn walk_expr_mut<F>(expr: &mut ast::Expr, func: &mut F) -> Result<()>
pub fn walk_expr_mut<F>(expr: &mut ast::Expr, func: &mut F) -> Result<WalkControl>
where
F: FnMut(&mut ast::Expr) -> Result<()>,
F: FnMut(&mut ast::Expr) -> Result<WalkControl>,
{
func(expr)?;
match expr {
ast::Expr::Between {
lhs, start, end, ..
} => {
walk_expr_mut(lhs, func)?;
walk_expr_mut(start, func)?;
walk_expr_mut(end, func)?;
}
ast::Expr::Binary(lhs, _, rhs) => {
walk_expr_mut(lhs, func)?;
walk_expr_mut(rhs, func)?;
}
ast::Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(base_expr) = base {
walk_expr_mut(base_expr, func)?;
}
for (when_expr, then_expr) in when_then_pairs {
walk_expr_mut(when_expr, func)?;
walk_expr_mut(then_expr, func)?;
}
if let Some(else_expr) = else_expr {
walk_expr_mut(else_expr, func)?;
}
}
ast::Expr::Cast { expr, .. } => {
walk_expr_mut(expr, func)?;
}
ast::Expr::Collate(expr, _) => {
walk_expr_mut(expr, func)?;
}
ast::Expr::Exists(_) | ast::Expr::Subquery(_) => {
// TODO: Walk through select statements if needed
}
ast::Expr::FunctionCall {
args,
order_by,
filter_over,
..
} => {
for arg in args {
walk_expr_mut(arg, func)?;
}
for sort_col in order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(filter_clause) = &mut filter_over.filter_clause {
walk_expr_mut(filter_clause, func)?;
}
if let Some(over_clause) = &mut filter_over.over_clause {
match over_clause {
ast::Over::Window(window) => {
for part_expr in &mut window.partition_by {
walk_expr_mut(part_expr, func)?;
}
for sort_col in &mut window.order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(frame_clause) = &mut window.frame_clause {
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
if let Some(end_bound) = &mut frame_clause.end {
walk_expr_mut_frame_bound(end_bound, func)?;
match func(expr)? {
WalkControl::Continue => {
match expr {
ast::Expr::Between {
lhs, start, end, ..
} => {
walk_expr_mut(lhs, func)?;
walk_expr_mut(start, func)?;
walk_expr_mut(end, func)?;
}
ast::Expr::Binary(lhs, _, rhs) => {
walk_expr_mut(lhs, func)?;
walk_expr_mut(rhs, func)?;
}
ast::Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(base_expr) = base {
walk_expr_mut(base_expr, func)?;
}
for (when_expr, then_expr) in when_then_pairs {
walk_expr_mut(when_expr, func)?;
walk_expr_mut(then_expr, func)?;
}
if let Some(else_expr) = else_expr {
walk_expr_mut(else_expr, func)?;
}
}
ast::Expr::Cast { expr, .. } => {
walk_expr_mut(expr, func)?;
}
ast::Expr::Collate(expr, _) => {
walk_expr_mut(expr, func)?;
}
ast::Expr::Exists(_) | ast::Expr::Subquery(_) => {
// TODO: Walk through select statements if needed
}
ast::Expr::FunctionCall {
args,
order_by,
filter_over,
..
} => {
for arg in args {
walk_expr_mut(arg, func)?;
}
for sort_col in order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(filter_clause) = &mut filter_over.filter_clause {
walk_expr_mut(filter_clause, func)?;
}
if let Some(over_clause) = &mut filter_over.over_clause {
match over_clause {
ast::Over::Window(window) => {
for part_expr in &mut window.partition_by {
walk_expr_mut(part_expr, func)?;
}
for sort_col in &mut window.order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(frame_clause) = &mut window.frame_clause {
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
if let Some(end_bound) = &mut frame_clause.end {
walk_expr_mut_frame_bound(end_bound, func)?;
}
}
}
ast::Over::Name(_) => {}
}
}
ast::Over::Name(_) => {}
}
ast::Expr::FunctionCallStar { filter_over, .. } => {
if let Some(ref mut filter_clause) = filter_over.filter_clause {
walk_expr_mut(filter_clause, func)?;
}
if let Some(ref mut over_clause) = filter_over.over_clause {
match over_clause {
ast::Over::Window(window) => {
for part_expr in &mut window.partition_by {
walk_expr_mut(part_expr, func)?;
}
for sort_col in &mut window.order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(frame_clause) = &mut window.frame_clause {
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
if let Some(end_bound) = &mut frame_clause.end {
walk_expr_mut_frame_bound(end_bound, func)?;
}
}
}
ast::Over::Name(_) => {}
}
}
}
ast::Expr::InList { lhs, rhs, .. } => {
walk_expr_mut(lhs, func)?;
for expr in rhs {
walk_expr_mut(expr, func)?;
}
}
ast::Expr::InSelect { lhs, rhs: _, .. } => {
walk_expr_mut(lhs, func)?;
// TODO: Walk through select statements if needed
}
ast::Expr::InTable { lhs, args, .. } => {
walk_expr_mut(lhs, func)?;
for expr in args {
walk_expr_mut(expr, func)?;
}
}
ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => {
walk_expr_mut(expr, func)?;
}
ast::Expr::Like {
lhs, rhs, escape, ..
} => {
walk_expr_mut(lhs, func)?;
walk_expr_mut(rhs, func)?;
if let Some(esc_expr) = escape {
walk_expr_mut(esc_expr, func)?;
}
}
ast::Expr::Parenthesized(exprs) => {
for expr in exprs {
walk_expr_mut(expr, func)?;
}
}
ast::Expr::Raise(_, expr) => {
if let Some(raise_expr) = expr {
walk_expr_mut(raise_expr, func)?;
}
}
ast::Expr::Unary(_, expr) => {
walk_expr_mut(expr, func)?;
}
ast::Expr::Id(_)
| ast::Expr::Column { .. }
| ast::Expr::RowId { .. }
| ast::Expr::Literal(_)
| ast::Expr::DoublyQualified(..)
| ast::Expr::Name(_)
| ast::Expr::Qualified(..)
| ast::Expr::Variable(_)
| ast::Expr::Register(_) => {
// No nested expressions
}
}
}
ast::Expr::FunctionCallStar { filter_over, .. } => {
if let Some(ref mut filter_clause) = filter_over.filter_clause {
walk_expr_mut(filter_clause, func)?;
}
if let Some(ref mut over_clause) = filter_over.over_clause {
match over_clause {
ast::Over::Window(window) => {
for part_expr in &mut window.partition_by {
walk_expr_mut(part_expr, func)?;
}
for sort_col in &mut window.order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(frame_clause) = &mut window.frame_clause {
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
if let Some(end_bound) = &mut frame_clause.end {
walk_expr_mut_frame_bound(end_bound, func)?;
}
}
}
ast::Over::Name(_) => {}
}
}
}
ast::Expr::InList { lhs, rhs, .. } => {
walk_expr_mut(lhs, func)?;
for expr in rhs {
walk_expr_mut(expr, func)?;
}
}
ast::Expr::InSelect { lhs, rhs: _, .. } => {
walk_expr_mut(lhs, func)?;
// TODO: Walk through select statements if needed
}
ast::Expr::InTable { lhs, args, .. } => {
walk_expr_mut(lhs, func)?;
for expr in args {
walk_expr_mut(expr, func)?;
}
}
ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => {
walk_expr_mut(expr, func)?;
}
ast::Expr::Like {
lhs, rhs, escape, ..
} => {
walk_expr_mut(lhs, func)?;
walk_expr_mut(rhs, func)?;
if let Some(esc_expr) = escape {
walk_expr_mut(esc_expr, func)?;
}
}
ast::Expr::Parenthesized(exprs) => {
for expr in exprs {
walk_expr_mut(expr, func)?;
}
}
ast::Expr::Raise(_, expr) => {
if let Some(raise_expr) = expr {
walk_expr_mut(raise_expr, func)?;
}
}
ast::Expr::Unary(_, expr) => {
walk_expr_mut(expr, func)?;
}
ast::Expr::Id(_)
| ast::Expr::Column { .. }
| ast::Expr::RowId { .. }
| ast::Expr::Literal(_)
| ast::Expr::DoublyQualified(..)
| ast::Expr::Name(_)
| ast::Expr::Qualified(..)
| ast::Expr::Variable(_)
| ast::Expr::Register(_) => {
// No nested expressions
}
}
Ok(())
WalkControl::SkipChildren => return Ok(WalkControl::Continue),
};
Ok(WalkControl::Continue)
}
fn walk_expr_mut_frame_bound<F>(bound: &mut ast::FrameBound, func: &mut F) -> Result<()>
fn walk_expr_mut_frame_bound<F>(bound: &mut ast::FrameBound, func: &mut F) -> Result<WalkControl>
where
F: FnMut(&mut ast::Expr) -> Result<()>,
F: FnMut(&mut ast::Expr) -> Result<WalkControl>,
{
match bound {
ast::FrameBound::Following(expr) | ast::FrameBound::Preceding(expr) => {
@@ -3408,7 +3411,7 @@ where
| ast::FrameBound::UnboundedPreceding => {}
}
Ok(())
Ok(WalkControl::Continue)
}
pub fn get_expr_affinity(

View File

@@ -15,7 +15,7 @@ use crate::{
parameters::PARAM_PREFIX,
schema::{Index, IndexColumn, Schema, Table},
translate::{
expr::walk_expr_mut, optimizer::access_method::AccessMethodParams,
expr::walk_expr_mut, expr::WalkControl, optimizer::access_method::AccessMethodParams,
optimizer::constraints::TableConstraints, plan::Scan, plan::TerminationKey,
},
types::SeekOp,
@@ -1443,71 +1443,74 @@ fn build_seek_def(
})
}
pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> {
walk_expr_mut(top_level_expr, &mut |expr: &mut ast::Expr| -> Result<()> {
match expr {
ast::Expr::Id(id) => {
// Convert "true" and "false" to 1 and 0
let id_bytes = id.as_str().as_bytes();
match_ignore_ascii_case!(match id_bytes {
b"true" => {
*expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_owned()));
}
b"false" => {
*expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_owned()));
}
_ => {}
})
}
ast::Expr::Variable(var) => {
if var.is_empty() {
// rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and
// all the expressions are rewritten in the order they come in the statement
*expr = ast::Expr::Variable(format!("{PARAM_PREFIX}{param_idx}"));
*param_idx += 1;
pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result<WalkControl> {
walk_expr_mut(
top_level_expr,
&mut |expr: &mut ast::Expr| -> Result<WalkControl> {
match expr {
ast::Expr::Id(id) => {
// Convert "true" and "false" to 1 and 0
let id_bytes = id.as_str().as_bytes();
match_ignore_ascii_case!(match id_bytes {
b"true" => {
*expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_owned()));
}
b"false" => {
*expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_owned()));
}
_ => {}
})
}
}
ast::Expr::Between {
lhs,
not,
start,
end,
} => {
// Convert `y NOT BETWEEN x AND z` to `x > y OR y > z`
let (lower_op, upper_op) = if *not {
(ast::Operator::Greater, ast::Operator::Greater)
} else {
// Convert `y BETWEEN x AND z` to `x <= y AND y <= z`
(ast::Operator::LessEquals, ast::Operator::LessEquals)
};
let start = start.take_ownership();
let lhs = lhs.take_ownership();
let end = end.take_ownership();
let lower_bound =
ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone()));
let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end));
if *not {
*expr = ast::Expr::Binary(
Box::new(lower_bound),
ast::Operator::Or,
Box::new(upper_bound),
);
} else {
*expr = ast::Expr::Binary(
Box::new(lower_bound),
ast::Operator::And,
Box::new(upper_bound),
);
ast::Expr::Variable(var) => {
if var.is_empty() {
// rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and
// all the expressions are rewritten in the order they come in the statement
*expr = ast::Expr::Variable(format!("{PARAM_PREFIX}{param_idx}"));
*param_idx += 1;
}
}
}
_ => {}
}
ast::Expr::Between {
lhs,
not,
start,
end,
} => {
// Convert `y NOT BETWEEN x AND z` to `x > y OR y > z`
let (lower_op, upper_op) = if *not {
(ast::Operator::Greater, ast::Operator::Greater)
} else {
// Convert `y BETWEEN x AND z` to `x <= y AND y <= z`
(ast::Operator::LessEquals, ast::Operator::LessEquals)
};
Ok(())
})
let start = start.take_ownership();
let lhs = lhs.take_ownership();
let end = end.take_ownership();
let lower_bound =
ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone()));
let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end));
if *not {
*expr = ast::Expr::Binary(
Box::new(lower_bound),
ast::Operator::Or,
Box::new(upper_bound),
);
} else {
*expr = ast::Expr::Binary(
Box::new(lower_bound),
ast::Operator::And,
Box::new(upper_bound),
);
}
}
_ => {}
}
Ok(WalkControl::Continue)
},
)
}
trait TakeOwnership {

View File

@@ -152,65 +152,39 @@ pub fn bind_column_references(
referenced_tables: &mut TableReferences,
result_columns: Option<&[ResultSetColumn]>,
connection: &Arc<crate::Connection>,
) -> Result<()> {
walk_expr_mut(top_level_expr, &mut |expr: &mut Expr| -> Result<()> {
match expr {
Expr::Id(id) => {
// true and false are special constants that are effectively aliases for 1 and 0
// and not identifiers of columns
let id_bytes = id.as_str().as_bytes();
match_ignore_ascii_case!(match id_bytes {
b"true" | b"false" => {
return Ok(());
}
_ => {}
});
let normalized_id = normalize_ident(id.as_str());
if !referenced_tables.joined_tables().is_empty() {
if let Some(row_id_expr) = parse_row_id(
&normalized_id,
referenced_tables.joined_tables()[0].internal_id,
|| referenced_tables.joined_tables().len() != 1,
)? {
*expr = row_id_expr;
return Ok(());
}
}
let mut match_result = None;
// First check joined tables
for joined_table in referenced_tables.joined_tables().iter() {
let col_idx = joined_table.table.columns().iter().position(|c| {
c.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
});
if col_idx.is_some() {
if match_result.is_some() {
crate::bail_parse_error!("Column {} is ambiguous", id.as_str());
) -> Result<WalkControl> {
walk_expr_mut(
top_level_expr,
&mut |expr: &mut Expr| -> Result<WalkControl> {
match expr {
Expr::Id(id) => {
// true and false are special constants that are effectively aliases for 1 and 0
// and not identifiers of columns
let id_bytes = id.as_str().as_bytes();
match_ignore_ascii_case!(match id_bytes {
b"true" | b"false" => {
return Ok(WalkControl::Continue);
}
let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap();
match_result = Some((
joined_table.internal_id,
col_idx.unwrap(),
col.is_rowid_alias,
));
}
}
_ => {}
});
let normalized_id = normalize_ident(id.as_str());
// Then check outer query references, if we still didn't find something.
// Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous)
// but in the case of subqueries, the inner query takes precedence.
// For example:
// SELECT * FROM t WHERE x = (SELECT x FROM t2)
// In this case, there is no ambiguity:
// - x in the outer query refers to t.x,
// - x in the inner query refers to t2.x.
if match_result.is_none() {
for outer_ref in referenced_tables.outer_query_refs().iter() {
let col_idx = outer_ref.table.columns().iter().position(|c| {
if !referenced_tables.joined_tables().is_empty() {
if let Some(row_id_expr) = parse_row_id(
&normalized_id,
referenced_tables.joined_tables()[0].internal_id,
|| referenced_tables.joined_tables().len() != 1,
)? {
*expr = row_id_expr;
return Ok(WalkControl::Continue);
}
}
let mut match_result = None;
// First check joined tables
for joined_table in referenced_tables.joined_tables().iter() {
let col_idx = joined_table.table.columns().iter().position(|c| {
c.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
@@ -219,151 +193,183 @@ pub fn bind_column_references(
if match_result.is_some() {
crate::bail_parse_error!("Column {} is ambiguous", id.as_str());
}
let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap();
match_result =
Some((outer_ref.internal_id, col_idx.unwrap(), col.is_rowid_alias));
let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap();
match_result = Some((
joined_table.internal_id,
col_idx.unwrap(),
col.is_rowid_alias,
));
}
}
}
if let Some((table_id, col_idx, is_rowid_alias)) = match_result {
*expr = Expr::Column {
database: None, // TODO: support different databases
table: table_id,
column: col_idx,
is_rowid_alias,
};
referenced_tables.mark_column_used(table_id, col_idx);
return Ok(());
}
if let Some(result_columns) = result_columns {
for result_column in result_columns.iter() {
if result_column
.name(referenced_tables)
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
{
*expr = result_column.expr.clone();
return Ok(());
// Then check outer query references, if we still didn't find something.
// Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous)
// but in the case of subqueries, the inner query takes precedence.
// For example:
// SELECT * FROM t WHERE x = (SELECT x FROM t2)
// In this case, there is no ambiguity:
// - x in the outer query refers to t.x,
// - x in the inner query refers to t2.x.
if match_result.is_none() {
for outer_ref in referenced_tables.outer_query_refs().iter() {
let col_idx = outer_ref.table.columns().iter().position(|c| {
c.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
});
if col_idx.is_some() {
if match_result.is_some() {
crate::bail_parse_error!("Column {} is ambiguous", id.as_str());
}
let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap();
match_result = Some((
outer_ref.internal_id,
col_idx.unwrap(),
col.is_rowid_alias,
));
}
}
}
if let Some((table_id, col_idx, is_rowid_alias)) = match_result {
*expr = Expr::Column {
database: None, // TODO: support different databases
table: table_id,
column: col_idx,
is_rowid_alias,
};
referenced_tables.mark_column_used(table_id, col_idx);
return Ok(WalkControl::Continue);
}
if let Some(result_columns) = result_columns {
for result_column in result_columns.iter() {
if result_column
.name(referenced_tables)
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
{
*expr = result_column.expr.clone();
return Ok(WalkControl::Continue);
}
}
}
// SQLite behavior: Only double-quoted identifiers get fallback to string literals
// Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns
if id.is_double_quoted() {
// Convert failed double-quoted identifier to string literal
*expr = Expr::Literal(Literal::String(id.as_str().to_string()));
Ok(WalkControl::Continue)
} else {
// Unquoted identifiers must resolve to columns - no fallback
crate::bail_parse_error!("no such column: {}", id.as_str())
}
}
// SQLite behavior: Only double-quoted identifiers get fallback to string literals
// Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns
if id.is_double_quoted() {
// Convert failed double-quoted identifier to string literal
*expr = Expr::Literal(Literal::String(id.as_str().to_string()));
Ok(())
} else {
// Unquoted identifiers must resolve to columns - no fallback
crate::bail_parse_error!("no such column: {}", id.as_str())
}
}
Expr::Qualified(tbl, id) => {
let normalized_table_name = normalize_ident(tbl.as_str());
let matching_tbl = referenced_tables
.find_table_and_internal_id_by_identifier(&normalized_table_name);
if matching_tbl.is_none() {
crate::bail_parse_error!("no such table: {}", normalized_table_name);
}
let (tbl_id, tbl) = matching_tbl.unwrap();
let normalized_id = normalize_ident(id.as_str());
Expr::Qualified(tbl, id) => {
let normalized_table_name = normalize_ident(tbl.as_str());
let matching_tbl = referenced_tables
.find_table_and_internal_id_by_identifier(&normalized_table_name);
if matching_tbl.is_none() {
crate::bail_parse_error!("no such table: {}", normalized_table_name);
}
let (tbl_id, tbl) = matching_tbl.unwrap();
let normalized_id = normalize_ident(id.as_str());
if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? {
*expr = row_id_expr;
if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? {
*expr = row_id_expr;
return Ok(());
}
let col_idx = tbl.columns().iter().position(|c| {
c.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
});
let Some(col_idx) = col_idx else {
crate::bail_parse_error!("no such column: {}", normalized_id);
};
let col = tbl.columns().get(col_idx).unwrap();
*expr = Expr::Column {
database: None, // TODO: support different databases
table: tbl_id,
column: col_idx,
is_rowid_alias: col.is_rowid_alias,
};
referenced_tables.mark_column_used(tbl_id, col_idx);
Ok(())
}
Expr::DoublyQualified(db_name, tbl_name, col_name) => {
let normalized_col_name = normalize_ident(col_name.as_str());
// Create a QualifiedName and use existing resolve_database_id method
let qualified_name = ast::QualifiedName {
db_name: Some(db_name.clone()),
name: tbl_name.clone(),
alias: None,
};
let database_id = connection.resolve_database_id(&qualified_name)?;
// Get the table from the specified database
let table = connection
.with_schema(database_id, |schema| schema.get_table(tbl_name.as_str()))
.ok_or_else(|| {
crate::LimboError::ParseError(format!(
"no such table: {}.{}",
db_name.as_str(),
tbl_name.as_str()
))
})?;
// Find the column in the table
let col_idx = table
.columns()
.iter()
.position(|c| {
return Ok(WalkControl::Continue);
}
let col_idx = tbl.columns().iter().position(|c| {
c.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name))
})
.ok_or_else(|| {
crate::LimboError::ParseError(format!(
"Column: {}.{}.{} not found",
db_name.as_str(),
tbl_name.as_str(),
col_name.as_str()
))
})?;
let col = table.columns().get(col_idx).unwrap();
// Check if this is a rowid alias
let is_rowid_alias = col.is_rowid_alias;
// Convert to Column expression - since this is a cross-database reference,
// we need to create a synthetic table reference for it
// For now, we'll error if the table isn't already in the referenced tables
let normalized_tbl_name = normalize_ident(tbl_name.as_str());
let matching_tbl = referenced_tables
.find_table_and_internal_id_by_identifier(&normalized_tbl_name);
if let Some((tbl_id, _)) = matching_tbl {
// Table is already in referenced tables, use existing internal ID
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
});
let Some(col_idx) = col_idx else {
crate::bail_parse_error!("no such column: {}", normalized_id);
};
let col = tbl.columns().get(col_idx).unwrap();
*expr = Expr::Column {
database: Some(database_id),
database: None, // TODO: support different databases
table: tbl_id,
column: col_idx,
is_rowid_alias,
is_rowid_alias: col.is_rowid_alias,
};
referenced_tables.mark_column_used(tbl_id, col_idx);
} else {
return Err(crate::LimboError::ParseError(format!(
"table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined"
)));
Ok(WalkControl::Continue)
}
Expr::DoublyQualified(db_name, tbl_name, col_name) => {
let normalized_col_name = normalize_ident(col_name.as_str());
Ok(())
// Create a QualifiedName and use existing resolve_database_id method
let qualified_name = ast::QualifiedName {
db_name: Some(db_name.clone()),
name: tbl_name.clone(),
alias: None,
};
let database_id = connection.resolve_database_id(&qualified_name)?;
// Get the table from the specified database
let table = connection
.with_schema(database_id, |schema| schema.get_table(tbl_name.as_str()))
.ok_or_else(|| {
crate::LimboError::ParseError(format!(
"no such table: {}.{}",
db_name.as_str(),
tbl_name.as_str()
))
})?;
// Find the column in the table
let col_idx = table
.columns()
.iter()
.position(|c| {
c.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name))
})
.ok_or_else(|| {
crate::LimboError::ParseError(format!(
"Column: {}.{}.{} not found",
db_name.as_str(),
tbl_name.as_str(),
col_name.as_str()
))
})?;
let col = table.columns().get(col_idx).unwrap();
// Check if this is a rowid alias
let is_rowid_alias = col.is_rowid_alias;
// Convert to Column expression - since this is a cross-database reference,
// we need to create a synthetic table reference for it
// For now, we'll error if the table isn't already in the referenced tables
let normalized_tbl_name = normalize_ident(tbl_name.as_str());
let matching_tbl = referenced_tables
.find_table_and_internal_id_by_identifier(&normalized_tbl_name);
if let Some((tbl_id, _)) = matching_tbl {
// Table is already in referenced tables, use existing internal ID
*expr = Expr::Column {
database: Some(database_id),
table: tbl_id,
column: col_idx,
is_rowid_alias,
};
referenced_tables.mark_column_used(tbl_id, col_idx);
} else {
return Err(crate::LimboError::ParseError(format!(
"table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined"
)));
}
Ok(WalkControl::Continue)
}
_ => Ok(WalkControl::Continue),
}
_ => Ok(()),
}
})
},
)
}
#[allow(clippy::too_many_arguments)]

View File

@@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc};
use turso_parser::ast::{self, Upsert};
use crate::translate::expr::WalkControl;
use crate::{
bail_parse_error,
error::SQLITE_CONSTRAINT_NOTNULL,
@@ -524,7 +525,7 @@ fn rewrite_upsert_expr_in_place(
current_start: usize,
conflict_rowid_reg: usize,
insertion: &Insertion,
) -> crate::Result<()> {
) -> crate::Result<WalkControl> {
use ast::Expr;
// helper: return the CURRENT-row register for a column (including rowid alias)
@@ -535,34 +536,37 @@ fn rewrite_upsert_expr_in_place(
let (idx, _c) = table.get_column_by_name(&normalize_ident(name))?;
Some(current_start + idx)
};
walk_expr_mut(e, &mut |expr: &mut ast::Expr| -> crate::Result<()> {
match expr {
// EXCLUDED.x -> insertion register
Expr::Qualified(ns, ast::Name::Ident(c)) => {
if ns.as_str().eq_ignore_ascii_case("excluded") {
let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else {
bail_parse_error!("no such column in EXCLUDED: {}", c);
};
*expr = Expr::Register(reg.register);
} else if ns.as_str().eq_ignore_ascii_case(table_name)
// t.x -> CURRENT, only if t matches the target table name (never "excluded")
{
if let Some(reg) = col_reg(c.as_str()) {
walk_expr_mut(
e,
&mut |expr: &mut ast::Expr| -> crate::Result<WalkControl> {
match expr {
// EXCLUDED.x -> insertion register
Expr::Qualified(ns, ast::Name::Ident(c)) => {
if ns.as_str().eq_ignore_ascii_case("excluded") {
let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else {
bail_parse_error!("no such column in EXCLUDED: {}", c);
};
*expr = Expr::Register(reg.register);
} else if ns.as_str().eq_ignore_ascii_case(table_name)
// t.x -> CURRENT, only if t matches the target table name (never "excluded")
{
if let Some(reg) = col_reg(c.as_str()) {
*expr = Expr::Register(reg);
}
}
}
// Unqualified column id -> CURRENT
Expr::Id(ast::Name::Ident(name)) => {
if let Some(reg) = col_reg(name.as_str()) {
*expr = Expr::Register(reg);
}
}
}
// Unqualified column id -> CURRENT
Expr::Id(ast::Name::Ident(name)) => {
if let Some(reg) = col_reg(name.as_str()) {
*expr = Expr::Register(reg);
Expr::RowId { .. } => {
*expr = Expr::Register(conflict_rowid_reg);
}
_ => {}
}
Expr::RowId { .. } => {
*expr = Expr::Register(conflict_rowid_reg);
}
_ => {}
}
Ok(())
})
Ok(WalkControl::Continue)
},
)
}