From d65b7eddc06424fa806955be63fb73e1b64e1a48 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 27 Oct 2025 18:18:48 +0400 Subject: [PATCH] add helper for simple binding of values in the AST --- core/util.rs | 152 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 1 deletion(-) diff --git a/core/util.rs b/core/util.rs index 8f0b5d9aa..745eb7a0b 100644 --- a/core/util.rs +++ b/core/util.rs @@ -2,7 +2,9 @@ use crate::incremental::view::IncrementalView; use crate::numeric::StrToF64; use crate::translate::emitter::TransactionMode; -use crate::translate::expr::WalkControl; +use crate::translate::expr::{walk_expr_mut, WalkControl}; +use crate::translate::plan::JoinedTable; +use crate::translate::planner::parse_row_id; use crate::types::IOResult; use crate::{ schema::{self, BTreeTable, Column, Schema, Table, Type, DBSP_TABLE_PREFIX}, @@ -318,6 +320,154 @@ pub fn check_literal_equivalency(lhs: &Literal, rhs: &Literal) -> bool { } } +/// bind AST identifiers to either Column or Rowid if possible +pub fn simple_bind_expr( + schema: &Schema, + joined_table: &JoinedTable, + result_columns: &[ast::ResultColumn], + expr: &mut ast::Expr, +) -> Result<()> { + let internal_id = joined_table.internal_id; + walk_expr_mut(expr, &mut |expr: &mut ast::Expr| -> Result { + #[allow(clippy::single_match)] + match expr { + Expr::Id(id) => { + let normalized_id = normalize_ident(id.as_str()); + + for result_column in result_columns.iter() { + if let ast::ResultColumn::Expr(result, Some(ast::As::As(alias))) = result_column + { + if alias.as_str().eq_ignore_ascii_case(&normalized_id) { + *expr = *result.clone(); + return Ok(WalkControl::Continue); + } + } + } + let col_idx = joined_table.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + if let Some(col_idx) = col_idx { + let col = joined_table.table.columns().get(col_idx).unwrap(); + *expr = ast::Expr::Column { + database: None, + table: internal_id, + column: col_idx, + is_rowid_alias: col.is_rowid_alias, + }; + } else { + // only if we haven't found a match, check for explicit rowid reference + let is_btree_table = matches!(joined_table.table, Table::BTree(_)); + if is_btree_table { + if let Some(rowid) = parse_row_id(&normalized_id, internal_id, || false)? { + *expr = rowid; + } + } + } + } + _ => {} + } + Ok(WalkControl::Continue) + }); + Ok(()) +} + +pub fn try_substitute_parameters( + pattern: &Expr, + parameters: &HashMap, +) -> Option> { + match pattern { + Expr::FunctionCall { + name, + distinctness, + args, + order_by, + filter_over, + } => { + let mut substituted = Vec::new(); + for arg in args { + substituted.push(try_substitute_parameters(arg, parameters)?); + } + Some(Box::new(Expr::FunctionCall { + args: substituted, + distinctness: *distinctness, + name: name.clone(), + order_by: order_by.clone(), + filter_over: filter_over.clone(), + })) + } + Expr::Variable(var) => { + let Ok(var) = var.parse::() else { + return None; + }; + Some(Box::new(parameters.get(&var)?.clone())) + } + _ => Some(Box::new(pattern.clone())), + } +} + +pub fn try_capture_parameters(pattern: &Expr, query: &Expr) -> Option> { + let mut captured = HashMap::new(); + match (pattern, query) { + ( + Expr::FunctionCall { + name: name1, + distinctness: distinct1, + args: args1, + order_by: order1, + filter_over: filter1, + }, + Expr::FunctionCall { + name: name2, + distinctness: distinct2, + args: args2, + order_by: order2, + filter_over: filter2, + }, + ) => { + if !name1.as_str().eq_ignore_ascii_case(name2.as_str()) { + return None; + } + if distinct1.is_some() || distinct2.is_some() { + return None; + } + if !order1.is_empty() || !order2.is_empty() { + return None; + } + if filter1.filter_clause.is_some() || filter1.over_clause.is_some() { + return None; + } + if filter2.filter_clause.is_some() || filter2.over_clause.is_some() { + return None; + } + for (arg1, arg2) in args1.iter().zip(args2.iter()) { + let result = try_capture_parameters(arg1, arg2)?; + captured.extend(result); + } + Some(captured) + } + (Expr::Variable(var), expr) => { + let Ok(var) = var.parse::() else { + return None; + }; + captured.insert(var, expr.clone()); + Some(captured) + } + ( + Expr::Id(_) | Expr::Name(_) | Expr::Column { .. }, + Expr::Id(_) | Expr::Name(_) | Expr::Column { .. }, + ) => { + if pattern == query { + Some(captured) + } else { + None + } + } + (_, _) => None, + } +} + /// This function is used to determine whether two expressions are logically /// equivalent in the context of queries, even if their representations /// differ. e.g.: `SUM(x)` and `sum(x)`, `x + y` and `y + x`