add helper for simple binding of values in the AST

This commit is contained in:
Nikita Sivukhin
2025-10-27 18:18:48 +04:00
parent 35b96ae8d8
commit d65b7eddc0

View File

@@ -2,7 +2,9 @@
use crate::incremental::view::IncrementalView;
use crate::numeric::StrToF64;
use crate::translate::emitter::TransactionMode;
use crate::translate::expr::WalkControl;
use crate::translate::expr::{walk_expr_mut, WalkControl};
use crate::translate::plan::JoinedTable;
use crate::translate::planner::parse_row_id;
use crate::types::IOResult;
use crate::{
schema::{self, BTreeTable, Column, Schema, Table, Type, DBSP_TABLE_PREFIX},
@@ -318,6 +320,154 @@ pub fn check_literal_equivalency(lhs: &Literal, rhs: &Literal) -> bool {
}
}
/// bind AST identifiers to either Column or Rowid if possible
pub fn simple_bind_expr(
schema: &Schema,
joined_table: &JoinedTable,
result_columns: &[ast::ResultColumn],
expr: &mut ast::Expr,
) -> Result<()> {
let internal_id = joined_table.internal_id;
walk_expr_mut(expr, &mut |expr: &mut ast::Expr| -> Result<WalkControl> {
#[allow(clippy::single_match)]
match expr {
Expr::Id(id) => {
let normalized_id = normalize_ident(id.as_str());
for result_column in result_columns.iter() {
if let ast::ResultColumn::Expr(result, Some(ast::As::As(alias))) = result_column
{
if alias.as_str().eq_ignore_ascii_case(&normalized_id) {
*expr = *result.clone();
return Ok(WalkControl::Continue);
}
}
}
let col_idx = joined_table.columns().iter().position(|c| {
c.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id))
});
if let Some(col_idx) = col_idx {
let col = joined_table.table.columns().get(col_idx).unwrap();
*expr = ast::Expr::Column {
database: None,
table: internal_id,
column: col_idx,
is_rowid_alias: col.is_rowid_alias,
};
} else {
// only if we haven't found a match, check for explicit rowid reference
let is_btree_table = matches!(joined_table.table, Table::BTree(_));
if is_btree_table {
if let Some(rowid) = parse_row_id(&normalized_id, internal_id, || false)? {
*expr = rowid;
}
}
}
}
_ => {}
}
Ok(WalkControl::Continue)
});
Ok(())
}
pub fn try_substitute_parameters(
pattern: &Expr,
parameters: &HashMap<i32, Expr>,
) -> Option<Box<Expr>> {
match pattern {
Expr::FunctionCall {
name,
distinctness,
args,
order_by,
filter_over,
} => {
let mut substituted = Vec::new();
for arg in args {
substituted.push(try_substitute_parameters(arg, parameters)?);
}
Some(Box::new(Expr::FunctionCall {
args: substituted,
distinctness: *distinctness,
name: name.clone(),
order_by: order_by.clone(),
filter_over: filter_over.clone(),
}))
}
Expr::Variable(var) => {
let Ok(var) = var.parse::<i32>() else {
return None;
};
Some(Box::new(parameters.get(&var)?.clone()))
}
_ => Some(Box::new(pattern.clone())),
}
}
pub fn try_capture_parameters(pattern: &Expr, query: &Expr) -> Option<HashMap<i32, Expr>> {
let mut captured = HashMap::new();
match (pattern, query) {
(
Expr::FunctionCall {
name: name1,
distinctness: distinct1,
args: args1,
order_by: order1,
filter_over: filter1,
},
Expr::FunctionCall {
name: name2,
distinctness: distinct2,
args: args2,
order_by: order2,
filter_over: filter2,
},
) => {
if !name1.as_str().eq_ignore_ascii_case(name2.as_str()) {
return None;
}
if distinct1.is_some() || distinct2.is_some() {
return None;
}
if !order1.is_empty() || !order2.is_empty() {
return None;
}
if filter1.filter_clause.is_some() || filter1.over_clause.is_some() {
return None;
}
if filter2.filter_clause.is_some() || filter2.over_clause.is_some() {
return None;
}
for (arg1, arg2) in args1.iter().zip(args2.iter()) {
let result = try_capture_parameters(arg1, arg2)?;
captured.extend(result);
}
Some(captured)
}
(Expr::Variable(var), expr) => {
let Ok(var) = var.parse::<i32>() else {
return None;
};
captured.insert(var, expr.clone());
Some(captured)
}
(
Expr::Id(_) | Expr::Name(_) | Expr::Column { .. },
Expr::Id(_) | Expr::Name(_) | Expr::Column { .. },
) => {
if pattern == query {
Some(captured)
} else {
None
}
}
(_, _) => None,
}
}
/// This function is used to determine whether two expressions are logically
/// equivalent in the context of queries, even if their representations
/// differ. e.g.: `SUM(x)` and `sum(x)`, `x + y` and `y + x`