From 6224cdbbd331d8cb2cef378dd3b8bc772a98828c Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Tue, 2 Sep 2025 08:56:02 +0200 Subject: [PATCH] Support WalkControl in walk_expr_mut Now walk_expr_mut can use WalkControl to skip parts of the expression tree. This makes it consistent with walk_expr. --- core/translate/expr.rs | 305 +++++++++++++------------- core/translate/optimizer/mod.rs | 129 +++++------ core/translate/planner.rs | 374 ++++++++++++++++---------------- core/translate/upsert.rs | 56 ++--- 4 files changed, 440 insertions(+), 424 deletions(-) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 2cf3b7f96..a9f4dc0b5 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -3237,167 +3237,170 @@ where } /// Recursively walks a mutable expression, applying a function to each sub-expression. -pub fn walk_expr_mut(expr: &mut ast::Expr, func: &mut F) -> Result<()> +pub fn walk_expr_mut(expr: &mut ast::Expr, func: &mut F) -> Result where - F: FnMut(&mut ast::Expr) -> Result<()>, + F: FnMut(&mut ast::Expr) -> Result, { - func(expr)?; - match expr { - ast::Expr::Between { - lhs, start, end, .. - } => { - walk_expr_mut(lhs, func)?; - walk_expr_mut(start, func)?; - walk_expr_mut(end, func)?; - } - ast::Expr::Binary(lhs, _, rhs) => { - walk_expr_mut(lhs, func)?; - walk_expr_mut(rhs, func)?; - } - ast::Expr::Case { - base, - when_then_pairs, - else_expr, - } => { - if let Some(base_expr) = base { - walk_expr_mut(base_expr, func)?; - } - for (when_expr, then_expr) in when_then_pairs { - walk_expr_mut(when_expr, func)?; - walk_expr_mut(then_expr, func)?; - } - if let Some(else_expr) = else_expr { - walk_expr_mut(else_expr, func)?; - } - } - ast::Expr::Cast { expr, .. } => { - walk_expr_mut(expr, func)?; - } - ast::Expr::Collate(expr, _) => { - walk_expr_mut(expr, func)?; - } - ast::Expr::Exists(_) | ast::Expr::Subquery(_) => { - // TODO: Walk through select statements if needed - } - ast::Expr::FunctionCall { - args, - order_by, - filter_over, - .. - } => { - for arg in args { - walk_expr_mut(arg, func)?; - } - for sort_col in order_by { - walk_expr_mut(&mut sort_col.expr, func)?; - } - if let Some(filter_clause) = &mut filter_over.filter_clause { - walk_expr_mut(filter_clause, func)?; - } - if let Some(over_clause) = &mut filter_over.over_clause { - match over_clause { - ast::Over::Window(window) => { - for part_expr in &mut window.partition_by { - walk_expr_mut(part_expr, func)?; - } - for sort_col in &mut window.order_by { - walk_expr_mut(&mut sort_col.expr, func)?; - } - if let Some(frame_clause) = &mut window.frame_clause { - walk_expr_mut_frame_bound(&mut frame_clause.start, func)?; - if let Some(end_bound) = &mut frame_clause.end { - walk_expr_mut_frame_bound(end_bound, func)?; + match func(expr)? { + WalkControl::Continue => { + match expr { + ast::Expr::Between { + lhs, start, end, .. + } => { + walk_expr_mut(lhs, func)?; + walk_expr_mut(start, func)?; + walk_expr_mut(end, func)?; + } + ast::Expr::Binary(lhs, _, rhs) => { + walk_expr_mut(lhs, func)?; + walk_expr_mut(rhs, func)?; + } + ast::Expr::Case { + base, + when_then_pairs, + else_expr, + } => { + if let Some(base_expr) = base { + walk_expr_mut(base_expr, func)?; + } + for (when_expr, then_expr) in when_then_pairs { + walk_expr_mut(when_expr, func)?; + walk_expr_mut(then_expr, func)?; + } + if let Some(else_expr) = else_expr { + walk_expr_mut(else_expr, func)?; + } + } + ast::Expr::Cast { expr, .. } => { + walk_expr_mut(expr, func)?; + } + ast::Expr::Collate(expr, _) => { + walk_expr_mut(expr, func)?; + } + ast::Expr::Exists(_) | ast::Expr::Subquery(_) => { + // TODO: Walk through select statements if needed + } + ast::Expr::FunctionCall { + args, + order_by, + filter_over, + .. + } => { + for arg in args { + walk_expr_mut(arg, func)?; + } + for sort_col in order_by { + walk_expr_mut(&mut sort_col.expr, func)?; + } + if let Some(filter_clause) = &mut filter_over.filter_clause { + walk_expr_mut(filter_clause, func)?; + } + if let Some(over_clause) = &mut filter_over.over_clause { + match over_clause { + ast::Over::Window(window) => { + for part_expr in &mut window.partition_by { + walk_expr_mut(part_expr, func)?; + } + for sort_col in &mut window.order_by { + walk_expr_mut(&mut sort_col.expr, func)?; + } + if let Some(frame_clause) = &mut window.frame_clause { + walk_expr_mut_frame_bound(&mut frame_clause.start, func)?; + if let Some(end_bound) = &mut frame_clause.end { + walk_expr_mut_frame_bound(end_bound, func)?; + } + } } + ast::Over::Name(_) => {} } } - ast::Over::Name(_) => {} + } + ast::Expr::FunctionCallStar { filter_over, .. } => { + if let Some(ref mut filter_clause) = filter_over.filter_clause { + walk_expr_mut(filter_clause, func)?; + } + if let Some(ref mut over_clause) = filter_over.over_clause { + match over_clause { + ast::Over::Window(window) => { + for part_expr in &mut window.partition_by { + walk_expr_mut(part_expr, func)?; + } + for sort_col in &mut window.order_by { + walk_expr_mut(&mut sort_col.expr, func)?; + } + if let Some(frame_clause) = &mut window.frame_clause { + walk_expr_mut_frame_bound(&mut frame_clause.start, func)?; + if let Some(end_bound) = &mut frame_clause.end { + walk_expr_mut_frame_bound(end_bound, func)?; + } + } + } + ast::Over::Name(_) => {} + } + } + } + ast::Expr::InList { lhs, rhs, .. } => { + walk_expr_mut(lhs, func)?; + for expr in rhs { + walk_expr_mut(expr, func)?; + } + } + ast::Expr::InSelect { lhs, rhs: _, .. } => { + walk_expr_mut(lhs, func)?; + // TODO: Walk through select statements if needed + } + ast::Expr::InTable { lhs, args, .. } => { + walk_expr_mut(lhs, func)?; + for expr in args { + walk_expr_mut(expr, func)?; + } + } + ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => { + walk_expr_mut(expr, func)?; + } + ast::Expr::Like { + lhs, rhs, escape, .. + } => { + walk_expr_mut(lhs, func)?; + walk_expr_mut(rhs, func)?; + if let Some(esc_expr) = escape { + walk_expr_mut(esc_expr, func)?; + } + } + ast::Expr::Parenthesized(exprs) => { + for expr in exprs { + walk_expr_mut(expr, func)?; + } + } + ast::Expr::Raise(_, expr) => { + if let Some(raise_expr) = expr { + walk_expr_mut(raise_expr, func)?; + } + } + ast::Expr::Unary(_, expr) => { + walk_expr_mut(expr, func)?; + } + ast::Expr::Id(_) + | ast::Expr::Column { .. } + | ast::Expr::RowId { .. } + | ast::Expr::Literal(_) + | ast::Expr::DoublyQualified(..) + | ast::Expr::Name(_) + | ast::Expr::Qualified(..) + | ast::Expr::Variable(_) + | ast::Expr::Register(_) => { + // No nested expressions } } } - ast::Expr::FunctionCallStar { filter_over, .. } => { - if let Some(ref mut filter_clause) = filter_over.filter_clause { - walk_expr_mut(filter_clause, func)?; - } - if let Some(ref mut over_clause) = filter_over.over_clause { - match over_clause { - ast::Over::Window(window) => { - for part_expr in &mut window.partition_by { - walk_expr_mut(part_expr, func)?; - } - for sort_col in &mut window.order_by { - walk_expr_mut(&mut sort_col.expr, func)?; - } - if let Some(frame_clause) = &mut window.frame_clause { - walk_expr_mut_frame_bound(&mut frame_clause.start, func)?; - if let Some(end_bound) = &mut frame_clause.end { - walk_expr_mut_frame_bound(end_bound, func)?; - } - } - } - ast::Over::Name(_) => {} - } - } - } - ast::Expr::InList { lhs, rhs, .. } => { - walk_expr_mut(lhs, func)?; - for expr in rhs { - walk_expr_mut(expr, func)?; - } - } - ast::Expr::InSelect { lhs, rhs: _, .. } => { - walk_expr_mut(lhs, func)?; - // TODO: Walk through select statements if needed - } - ast::Expr::InTable { lhs, args, .. } => { - walk_expr_mut(lhs, func)?; - for expr in args { - walk_expr_mut(expr, func)?; - } - } - ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => { - walk_expr_mut(expr, func)?; - } - ast::Expr::Like { - lhs, rhs, escape, .. - } => { - walk_expr_mut(lhs, func)?; - walk_expr_mut(rhs, func)?; - if let Some(esc_expr) = escape { - walk_expr_mut(esc_expr, func)?; - } - } - ast::Expr::Parenthesized(exprs) => { - for expr in exprs { - walk_expr_mut(expr, func)?; - } - } - ast::Expr::Raise(_, expr) => { - if let Some(raise_expr) = expr { - walk_expr_mut(raise_expr, func)?; - } - } - ast::Expr::Unary(_, expr) => { - walk_expr_mut(expr, func)?; - } - ast::Expr::Id(_) - | ast::Expr::Column { .. } - | ast::Expr::RowId { .. } - | ast::Expr::Literal(_) - | ast::Expr::DoublyQualified(..) - | ast::Expr::Name(_) - | ast::Expr::Qualified(..) - | ast::Expr::Variable(_) - | ast::Expr::Register(_) => { - // No nested expressions - } - } - - Ok(()) + WalkControl::SkipChildren => return Ok(WalkControl::Continue), + }; + Ok(WalkControl::Continue) } -fn walk_expr_mut_frame_bound(bound: &mut ast::FrameBound, func: &mut F) -> Result<()> +fn walk_expr_mut_frame_bound(bound: &mut ast::FrameBound, func: &mut F) -> Result where - F: FnMut(&mut ast::Expr) -> Result<()>, + F: FnMut(&mut ast::Expr) -> Result, { match bound { ast::FrameBound::Following(expr) | ast::FrameBound::Preceding(expr) => { @@ -3408,7 +3411,7 @@ where | ast::FrameBound::UnboundedPreceding => {} } - Ok(()) + Ok(WalkControl::Continue) } pub fn get_expr_affinity( diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index f2222975c..74e458108 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -15,7 +15,7 @@ use crate::{ parameters::PARAM_PREFIX, schema::{Index, IndexColumn, Schema, Table}, translate::{ - expr::walk_expr_mut, optimizer::access_method::AccessMethodParams, + expr::walk_expr_mut, expr::WalkControl, optimizer::access_method::AccessMethodParams, optimizer::constraints::TableConstraints, plan::Scan, plan::TerminationKey, }, types::SeekOp, @@ -1443,71 +1443,74 @@ fn build_seek_def( }) } -pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result<()> { - walk_expr_mut(top_level_expr, &mut |expr: &mut ast::Expr| -> Result<()> { - match expr { - ast::Expr::Id(id) => { - // Convert "true" and "false" to 1 and 0 - let id_bytes = id.as_str().as_bytes(); - match_ignore_ascii_case!(match id_bytes { - b"true" => { - *expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_owned())); - } - b"false" => { - *expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_owned())); - } - _ => {} - }) - } - ast::Expr::Variable(var) => { - if var.is_empty() { - // rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and - // all the expressions are rewritten in the order they come in the statement - *expr = ast::Expr::Variable(format!("{PARAM_PREFIX}{param_idx}")); - *param_idx += 1; +pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result { + walk_expr_mut( + top_level_expr, + &mut |expr: &mut ast::Expr| -> Result { + match expr { + ast::Expr::Id(id) => { + // Convert "true" and "false" to 1 and 0 + let id_bytes = id.as_str().as_bytes(); + match_ignore_ascii_case!(match id_bytes { + b"true" => { + *expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_owned())); + } + b"false" => { + *expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_owned())); + } + _ => {} + }) } - } - ast::Expr::Between { - lhs, - not, - start, - end, - } => { - // Convert `y NOT BETWEEN x AND z` to `x > y OR y > z` - let (lower_op, upper_op) = if *not { - (ast::Operator::Greater, ast::Operator::Greater) - } else { - // Convert `y BETWEEN x AND z` to `x <= y AND y <= z` - (ast::Operator::LessEquals, ast::Operator::LessEquals) - }; - - let start = start.take_ownership(); - let lhs = lhs.take_ownership(); - let end = end.take_ownership(); - - let lower_bound = - ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone())); - let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end)); - - if *not { - *expr = ast::Expr::Binary( - Box::new(lower_bound), - ast::Operator::Or, - Box::new(upper_bound), - ); - } else { - *expr = ast::Expr::Binary( - Box::new(lower_bound), - ast::Operator::And, - Box::new(upper_bound), - ); + ast::Expr::Variable(var) => { + if var.is_empty() { + // rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and + // all the expressions are rewritten in the order they come in the statement + *expr = ast::Expr::Variable(format!("{PARAM_PREFIX}{param_idx}")); + *param_idx += 1; + } } - } - _ => {} - } + ast::Expr::Between { + lhs, + not, + start, + end, + } => { + // Convert `y NOT BETWEEN x AND z` to `x > y OR y > z` + let (lower_op, upper_op) = if *not { + (ast::Operator::Greater, ast::Operator::Greater) + } else { + // Convert `y BETWEEN x AND z` to `x <= y AND y <= z` + (ast::Operator::LessEquals, ast::Operator::LessEquals) + }; - Ok(()) - }) + let start = start.take_ownership(); + let lhs = lhs.take_ownership(); + let end = end.take_ownership(); + + let lower_bound = + ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone())); + let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end)); + + if *not { + *expr = ast::Expr::Binary( + Box::new(lower_bound), + ast::Operator::Or, + Box::new(upper_bound), + ); + } else { + *expr = ast::Expr::Binary( + Box::new(lower_bound), + ast::Operator::And, + Box::new(upper_bound), + ); + } + } + _ => {} + } + + Ok(WalkControl::Continue) + }, + ) } trait TakeOwnership { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 92aa619bc..db5735381 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -152,65 +152,39 @@ pub fn bind_column_references( referenced_tables: &mut TableReferences, result_columns: Option<&[ResultSetColumn]>, connection: &Arc, -) -> Result<()> { - walk_expr_mut(top_level_expr, &mut |expr: &mut Expr| -> Result<()> { - match expr { - Expr::Id(id) => { - // true and false are special constants that are effectively aliases for 1 and 0 - // and not identifiers of columns - let id_bytes = id.as_str().as_bytes(); - match_ignore_ascii_case!(match id_bytes { - b"true" | b"false" => { - return Ok(()); - } - _ => {} - }); - let normalized_id = normalize_ident(id.as_str()); - - if !referenced_tables.joined_tables().is_empty() { - if let Some(row_id_expr) = parse_row_id( - &normalized_id, - referenced_tables.joined_tables()[0].internal_id, - || referenced_tables.joined_tables().len() != 1, - )? { - *expr = row_id_expr; - - return Ok(()); - } - } - let mut match_result = None; - - // First check joined tables - for joined_table in referenced_tables.joined_tables().iter() { - let col_idx = joined_table.table.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - if col_idx.is_some() { - if match_result.is_some() { - crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); +) -> Result { + walk_expr_mut( + top_level_expr, + &mut |expr: &mut Expr| -> Result { + match expr { + Expr::Id(id) => { + // true and false are special constants that are effectively aliases for 1 and 0 + // and not identifiers of columns + let id_bytes = id.as_str().as_bytes(); + match_ignore_ascii_case!(match id_bytes { + b"true" | b"false" => { + return Ok(WalkControl::Continue); } - let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some(( - joined_table.internal_id, - col_idx.unwrap(), - col.is_rowid_alias, - )); - } - } + _ => {} + }); + let normalized_id = normalize_ident(id.as_str()); - // Then check outer query references, if we still didn't find something. - // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) - // but in the case of subqueries, the inner query takes precedence. - // For example: - // SELECT * FROM t WHERE x = (SELECT x FROM t2) - // In this case, there is no ambiguity: - // - x in the outer query refers to t.x, - // - x in the inner query refers to t2.x. - if match_result.is_none() { - for outer_ref in referenced_tables.outer_query_refs().iter() { - let col_idx = outer_ref.table.columns().iter().position(|c| { + if !referenced_tables.joined_tables().is_empty() { + if let Some(row_id_expr) = parse_row_id( + &normalized_id, + referenced_tables.joined_tables()[0].internal_id, + || referenced_tables.joined_tables().len() != 1, + )? { + *expr = row_id_expr; + + return Ok(WalkControl::Continue); + } + } + let mut match_result = None; + + // First check joined tables + for joined_table in referenced_tables.joined_tables().iter() { + let col_idx = joined_table.table.columns().iter().position(|c| { c.name .as_ref() .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) @@ -219,151 +193,183 @@ pub fn bind_column_references( if match_result.is_some() { crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); } - let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); - match_result = - Some((outer_ref.internal_id, col_idx.unwrap(), col.is_rowid_alias)); + let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + joined_table.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); } } - } - if let Some((table_id, col_idx, is_rowid_alias)) = match_result { - *expr = Expr::Column { - database: None, // TODO: support different databases - table: table_id, - column: col_idx, - is_rowid_alias, - }; - referenced_tables.mark_column_used(table_id, col_idx); - return Ok(()); - } - - if let Some(result_columns) = result_columns { - for result_column in result_columns.iter() { - if result_column - .name(referenced_tables) - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - { - *expr = result_column.expr.clone(); - return Ok(()); + // Then check outer query references, if we still didn't find something. + // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) + // but in the case of subqueries, the inner query takes precedence. + // For example: + // SELECT * FROM t WHERE x = (SELECT x FROM t2) + // In this case, there is no ambiguity: + // - x in the outer query refers to t.x, + // - x in the inner query refers to t2.x. + if match_result.is_none() { + for outer_ref in referenced_tables.outer_query_refs().iter() { + let col_idx = outer_ref.table.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + if col_idx.is_some() { + if match_result.is_some() { + crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); + } + let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + outer_ref.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); + } } } + + if let Some((table_id, col_idx, is_rowid_alias)) = match_result { + *expr = Expr::Column { + database: None, // TODO: support different databases + table: table_id, + column: col_idx, + is_rowid_alias, + }; + referenced_tables.mark_column_used(table_id, col_idx); + return Ok(WalkControl::Continue); + } + + if let Some(result_columns) = result_columns { + for result_column in result_columns.iter() { + if result_column + .name(referenced_tables) + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + { + *expr = result_column.expr.clone(); + return Ok(WalkControl::Continue); + } + } + } + // SQLite behavior: Only double-quoted identifiers get fallback to string literals + // Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns + if id.is_double_quoted() { + // Convert failed double-quoted identifier to string literal + *expr = Expr::Literal(Literal::String(id.as_str().to_string())); + Ok(WalkControl::Continue) + } else { + // Unquoted identifiers must resolve to columns - no fallback + crate::bail_parse_error!("no such column: {}", id.as_str()) + } } - // SQLite behavior: Only double-quoted identifiers get fallback to string literals - // Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns - if id.is_double_quoted() { - // Convert failed double-quoted identifier to string literal - *expr = Expr::Literal(Literal::String(id.as_str().to_string())); - Ok(()) - } else { - // Unquoted identifiers must resolve to columns - no fallback - crate::bail_parse_error!("no such column: {}", id.as_str()) - } - } - Expr::Qualified(tbl, id) => { - let normalized_table_name = normalize_ident(tbl.as_str()); - let matching_tbl = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_table_name); - if matching_tbl.is_none() { - crate::bail_parse_error!("no such table: {}", normalized_table_name); - } - let (tbl_id, tbl) = matching_tbl.unwrap(); - let normalized_id = normalize_ident(id.as_str()); + Expr::Qualified(tbl, id) => { + let normalized_table_name = normalize_ident(tbl.as_str()); + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_table_name); + if matching_tbl.is_none() { + crate::bail_parse_error!("no such table: {}", normalized_table_name); + } + let (tbl_id, tbl) = matching_tbl.unwrap(); + let normalized_id = normalize_ident(id.as_str()); - if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? { - *expr = row_id_expr; + if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? { + *expr = row_id_expr; - return Ok(()); - } - let col_idx = tbl.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - let Some(col_idx) = col_idx else { - crate::bail_parse_error!("no such column: {}", normalized_id); - }; - let col = tbl.columns().get(col_idx).unwrap(); - *expr = Expr::Column { - database: None, // TODO: support different databases - table: tbl_id, - column: col_idx, - is_rowid_alias: col.is_rowid_alias, - }; - referenced_tables.mark_column_used(tbl_id, col_idx); - Ok(()) - } - Expr::DoublyQualified(db_name, tbl_name, col_name) => { - let normalized_col_name = normalize_ident(col_name.as_str()); - - // Create a QualifiedName and use existing resolve_database_id method - let qualified_name = ast::QualifiedName { - db_name: Some(db_name.clone()), - name: tbl_name.clone(), - alias: None, - }; - let database_id = connection.resolve_database_id(&qualified_name)?; - - // Get the table from the specified database - let table = connection - .with_schema(database_id, |schema| schema.get_table(tbl_name.as_str())) - .ok_or_else(|| { - crate::LimboError::ParseError(format!( - "no such table: {}.{}", - db_name.as_str(), - tbl_name.as_str() - )) - })?; - - // Find the column in the table - let col_idx = table - .columns() - .iter() - .position(|c| { + return Ok(WalkControl::Continue); + } + let col_idx = tbl.columns().iter().position(|c| { c.name .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name)) - }) - .ok_or_else(|| { - crate::LimboError::ParseError(format!( - "Column: {}.{}.{} not found", - db_name.as_str(), - tbl_name.as_str(), - col_name.as_str() - )) - })?; - - let col = table.columns().get(col_idx).unwrap(); - - // Check if this is a rowid alias - let is_rowid_alias = col.is_rowid_alias; - - // Convert to Column expression - since this is a cross-database reference, - // we need to create a synthetic table reference for it - // For now, we'll error if the table isn't already in the referenced tables - let normalized_tbl_name = normalize_ident(tbl_name.as_str()); - let matching_tbl = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_tbl_name); - - if let Some((tbl_id, _)) = matching_tbl { - // Table is already in referenced tables, use existing internal ID + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + let Some(col_idx) = col_idx else { + crate::bail_parse_error!("no such column: {}", normalized_id); + }; + let col = tbl.columns().get(col_idx).unwrap(); *expr = Expr::Column { - database: Some(database_id), + database: None, // TODO: support different databases table: tbl_id, column: col_idx, - is_rowid_alias, + is_rowid_alias: col.is_rowid_alias, }; referenced_tables.mark_column_used(tbl_id, col_idx); - } else { - return Err(crate::LimboError::ParseError(format!( - "table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined" - ))); + Ok(WalkControl::Continue) } + Expr::DoublyQualified(db_name, tbl_name, col_name) => { + let normalized_col_name = normalize_ident(col_name.as_str()); - Ok(()) + // Create a QualifiedName and use existing resolve_database_id method + let qualified_name = ast::QualifiedName { + db_name: Some(db_name.clone()), + name: tbl_name.clone(), + alias: None, + }; + let database_id = connection.resolve_database_id(&qualified_name)?; + + // Get the table from the specified database + let table = connection + .with_schema(database_id, |schema| schema.get_table(tbl_name.as_str())) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "no such table: {}.{}", + db_name.as_str(), + tbl_name.as_str() + )) + })?; + + // Find the column in the table + let col_idx = table + .columns() + .iter() + .position(|c| { + c.name + .as_ref() + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name)) + }) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column: {}.{}.{} not found", + db_name.as_str(), + tbl_name.as_str(), + col_name.as_str() + )) + })?; + + let col = table.columns().get(col_idx).unwrap(); + + // Check if this is a rowid alias + let is_rowid_alias = col.is_rowid_alias; + + // Convert to Column expression - since this is a cross-database reference, + // we need to create a synthetic table reference for it + // For now, we'll error if the table isn't already in the referenced tables + let normalized_tbl_name = normalize_ident(tbl_name.as_str()); + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_tbl_name); + + if let Some((tbl_id, _)) = matching_tbl { + // Table is already in referenced tables, use existing internal ID + *expr = Expr::Column { + database: Some(database_id), + table: tbl_id, + column: col_idx, + is_rowid_alias, + }; + referenced_tables.mark_column_used(tbl_id, col_idx); + } else { + return Err(crate::LimboError::ParseError(format!( + "table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined" + ))); + } + + Ok(WalkControl::Continue) + } + _ => Ok(WalkControl::Continue), } - _ => Ok(()), - } - }) + }, + ) } #[allow(clippy::too_many_arguments)] diff --git a/core/translate/upsert.rs b/core/translate/upsert.rs index 145a55eda..5b208ed38 100644 --- a/core/translate/upsert.rs +++ b/core/translate/upsert.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use turso_parser::ast::{self, Upsert}; +use crate::translate::expr::WalkControl; use crate::{ bail_parse_error, error::SQLITE_CONSTRAINT_NOTNULL, @@ -524,7 +525,7 @@ fn rewrite_upsert_expr_in_place( current_start: usize, conflict_rowid_reg: usize, insertion: &Insertion, -) -> crate::Result<()> { +) -> crate::Result { use ast::Expr; // helper: return the CURRENT-row register for a column (including rowid alias) @@ -535,34 +536,37 @@ fn rewrite_upsert_expr_in_place( let (idx, _c) = table.get_column_by_name(&normalize_ident(name))?; Some(current_start + idx) }; - walk_expr_mut(e, &mut |expr: &mut ast::Expr| -> crate::Result<()> { - match expr { - // EXCLUDED.x -> insertion register - Expr::Qualified(ns, ast::Name::Ident(c)) => { - if ns.as_str().eq_ignore_ascii_case("excluded") { - let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else { - bail_parse_error!("no such column in EXCLUDED: {}", c); - }; - *expr = Expr::Register(reg.register); - } else if ns.as_str().eq_ignore_ascii_case(table_name) - // t.x -> CURRENT, only if t matches the target table name (never "excluded") - { - if let Some(reg) = col_reg(c.as_str()) { + walk_expr_mut( + e, + &mut |expr: &mut ast::Expr| -> crate::Result { + match expr { + // EXCLUDED.x -> insertion register + Expr::Qualified(ns, ast::Name::Ident(c)) => { + if ns.as_str().eq_ignore_ascii_case("excluded") { + let Some(reg) = insertion.get_col_mapping_by_name(c.as_str()) else { + bail_parse_error!("no such column in EXCLUDED: {}", c); + }; + *expr = Expr::Register(reg.register); + } else if ns.as_str().eq_ignore_ascii_case(table_name) + // t.x -> CURRENT, only if t matches the target table name (never "excluded") + { + if let Some(reg) = col_reg(c.as_str()) { + *expr = Expr::Register(reg); + } + } + } + // Unqualified column id -> CURRENT + Expr::Id(ast::Name::Ident(name)) => { + if let Some(reg) = col_reg(name.as_str()) { *expr = Expr::Register(reg); } } - } - // Unqualified column id -> CURRENT - Expr::Id(ast::Name::Ident(name)) => { - if let Some(reg) = col_reg(name.as_str()) { - *expr = Expr::Register(reg); + Expr::RowId { .. } => { + *expr = Expr::Register(conflict_rowid_reg); } + _ => {} } - Expr::RowId { .. } => { - *expr = Expr::Register(conflict_rowid_reg); - } - _ => {} - } - Ok(()) - }) + Ok(WalkControl::Continue) + }, + ) }