refactor: use walk_expr_mut() in bind_column_references()

This commit is contained in:
Jussi Saurio
2025-05-23 15:56:49 +03:00
parent 40a4d162bc
commit 2ab5c5f6a9

View File

@@ -10,6 +10,7 @@ use super::{
use crate::{
function::Func,
schema::{Schema, Table},
translate::expr::walk_expr_mut,
util::{exprs_are_equivalent, normalize_ident, vtable_args},
vdbe::BranchOffset,
Result,
@@ -99,202 +100,108 @@ pub fn resolve_aggregates(expr: &Expr, aggs: &mut Vec<Aggregate>) -> Result<bool
}
pub fn bind_column_references(
expr: &mut Expr,
top_level_expr: &mut Expr,
referenced_tables: &mut [TableReference],
result_columns: Option<&[ResultSetColumn]>,
) -> Result<()> {
match expr {
Expr::Id(id) => {
// true and false are special constants that are effectively aliases for 1 and 0
// and not identifiers of columns
if id.0.eq_ignore_ascii_case("true") || id.0.eq_ignore_ascii_case("false") {
return Ok(());
}
let normalized_id = normalize_ident(id.0.as_str());
walk_expr_mut(top_level_expr, &mut |expr: &mut Expr| -> Result<()> {
match expr {
Expr::Id(id) => {
// true and false are special constants that are effectively aliases for 1 and 0
// and not identifiers of columns
if id.0.eq_ignore_ascii_case("true") || id.0.eq_ignore_ascii_case("false") {
return Ok(());
}
let normalized_id = normalize_ident(id.0.as_str());
if !referenced_tables.is_empty() {
if let Some(row_id_expr) =
parse_row_id(&normalized_id, 0, || referenced_tables.len() != 1)?
{
if !referenced_tables.is_empty() {
if let Some(row_id_expr) =
parse_row_id(&normalized_id, 0, || referenced_tables.len() != 1)?
{
*expr = row_id_expr;
return Ok(());
}
}
let mut match_result = None;
for (tbl_idx, table) in referenced_tables.iter().enumerate() {
let col_idx = table.columns().iter().position(|c| {
c.name
.as_ref()
.map_or(false, |name| name.eq_ignore_ascii_case(&normalized_id))
});
if col_idx.is_some() {
if match_result.is_some() {
crate::bail_parse_error!("Column {} is ambiguous", id.0);
}
let col = table.columns().get(col_idx.unwrap()).unwrap();
match_result = Some((tbl_idx, col_idx.unwrap(), col.is_rowid_alias));
}
}
if let Some((tbl_idx, col_idx, is_rowid_alias)) = match_result {
*expr = Expr::Column {
database: None, // TODO: support different databases
table: tbl_idx,
column: col_idx,
is_rowid_alias,
};
referenced_tables[tbl_idx].mark_column_used(col_idx);
return Ok(());
}
if let Some(result_columns) = result_columns {
for result_column in result_columns.iter() {
if result_column
.name(referenced_tables)
.map_or(false, |name| name.eq_ignore_ascii_case(&normalized_id))
{
*expr = result_column.expr.clone();
return Ok(());
}
}
}
crate::bail_parse_error!("Column {} not found", id.0);
}
Expr::Qualified(tbl, id) => {
let normalized_table_name = normalize_ident(tbl.0.as_str());
let matching_tbl_idx = referenced_tables
.iter()
.position(|t| t.identifier.eq_ignore_ascii_case(&normalized_table_name));
if matching_tbl_idx.is_none() {
crate::bail_parse_error!("Table {} not found", normalized_table_name);
}
let tbl_idx = matching_tbl_idx.unwrap();
let normalized_id = normalize_ident(id.0.as_str());
if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_idx, || false)? {
*expr = row_id_expr;
return Ok(());
}
}
let mut match_result = None;
for (tbl_idx, table) in referenced_tables.iter().enumerate() {
let col_idx = table.columns().iter().position(|c| {
let col_idx = referenced_tables[tbl_idx].columns().iter().position(|c| {
c.name
.as_ref()
.map_or(false, |name| name.eq_ignore_ascii_case(&normalized_id))
});
if col_idx.is_some() {
if match_result.is_some() {
crate::bail_parse_error!("Column {} is ambiguous", id.0);
}
let col = table.columns().get(col_idx.unwrap()).unwrap();
match_result = Some((tbl_idx, col_idx.unwrap(), col.is_rowid_alias));
if col_idx.is_none() {
crate::bail_parse_error!("Column {} not found", normalized_id);
}
}
if let Some((tbl_idx, col_idx, is_rowid_alias)) = match_result {
let col = referenced_tables[tbl_idx]
.columns()
.get(col_idx.unwrap())
.unwrap();
*expr = Expr::Column {
database: None, // TODO: support different databases
table: tbl_idx,
column: col_idx,
is_rowid_alias,
column: col_idx.unwrap(),
is_rowid_alias: col.is_rowid_alias,
};
referenced_tables[tbl_idx].mark_column_used(col_idx);
return Ok(());
referenced_tables[tbl_idx].mark_column_used(col_idx.unwrap());
Ok(())
}
if let Some(result_columns) = result_columns {
for result_column in result_columns.iter() {
if result_column
.name(referenced_tables)
.map_or(false, |name| name.eq_ignore_ascii_case(&normalized_id))
{
*expr = result_column.expr.clone();
return Ok(());
}
}
}
crate::bail_parse_error!("Column {} not found", id.0);
_ => Ok(()),
}
Expr::Qualified(tbl, id) => {
let normalized_table_name = normalize_ident(tbl.0.as_str());
let matching_tbl_idx = referenced_tables
.iter()
.position(|t| t.identifier.eq_ignore_ascii_case(&normalized_table_name));
if matching_tbl_idx.is_none() {
crate::bail_parse_error!("Table {} not found", normalized_table_name);
}
let tbl_idx = matching_tbl_idx.unwrap();
let normalized_id = normalize_ident(id.0.as_str());
if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_idx, || false)? {
*expr = row_id_expr;
return Ok(());
}
let col_idx = referenced_tables[tbl_idx].columns().iter().position(|c| {
c.name
.as_ref()
.map_or(false, |name| name.eq_ignore_ascii_case(&normalized_id))
});
if col_idx.is_none() {
crate::bail_parse_error!("Column {} not found", normalized_id);
}
let col = referenced_tables[tbl_idx]
.columns()
.get(col_idx.unwrap())
.unwrap();
*expr = Expr::Column {
database: None, // TODO: support different databases
table: tbl_idx,
column: col_idx.unwrap(),
is_rowid_alias: col.is_rowid_alias,
};
referenced_tables[tbl_idx].mark_column_used(col_idx.unwrap());
Ok(())
}
Expr::Between {
lhs,
not: _,
start,
end,
} => {
bind_column_references(lhs, referenced_tables, result_columns)?;
bind_column_references(start, referenced_tables, result_columns)?;
bind_column_references(end, referenced_tables, result_columns)?;
Ok(())
}
Expr::Binary(expr, _operator, expr1) => {
bind_column_references(expr, referenced_tables, result_columns)?;
bind_column_references(expr1, referenced_tables, result_columns)?;
Ok(())
}
Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(base) = base {
bind_column_references(base, referenced_tables, result_columns)?;
}
for (when, then) in when_then_pairs {
bind_column_references(when, referenced_tables, result_columns)?;
bind_column_references(then, referenced_tables, result_columns)?;
}
if let Some(else_expr) = else_expr {
bind_column_references(else_expr, referenced_tables, result_columns)?;
}
Ok(())
}
Expr::Cast { expr, type_name: _ } => {
bind_column_references(expr, referenced_tables, result_columns)
}
Expr::Collate(expr, _string) => {
bind_column_references(expr, referenced_tables, result_columns)
}
Expr::FunctionCall {
name: _,
distinctness: _,
args,
order_by: _,
filter_over: _,
} => {
if let Some(args) = args {
for arg in args {
bind_column_references(arg, referenced_tables, result_columns)?;
}
}
Ok(())
}
// Already bound earlier
Expr::Column { .. } | Expr::RowId { .. } => Ok(()),
Expr::DoublyQualified(_, _, _) => todo!(),
Expr::Exists(_) => todo!(),
Expr::FunctionCallStar { .. } => Ok(()),
Expr::InList { lhs, not: _, rhs } => {
bind_column_references(lhs, referenced_tables, result_columns)?;
if let Some(rhs) = rhs {
for arg in rhs {
bind_column_references(arg, referenced_tables, result_columns)?;
}
}
Ok(())
}
Expr::InSelect { .. } => todo!(),
Expr::InTable { .. } => todo!(),
Expr::IsNull(expr) => {
bind_column_references(expr, referenced_tables, result_columns)?;
Ok(())
}
Expr::Like { lhs, rhs, .. } => {
bind_column_references(lhs, referenced_tables, result_columns)?;
bind_column_references(rhs, referenced_tables, result_columns)?;
Ok(())
}
Expr::Literal(_) => Ok(()),
Expr::Name(_) => todo!(),
Expr::NotNull(expr) => {
bind_column_references(expr, referenced_tables, result_columns)?;
Ok(())
}
Expr::Parenthesized(expr) => {
for e in expr.iter_mut() {
bind_column_references(e, referenced_tables, result_columns)?;
}
Ok(())
}
Expr::Raise(_, _) => todo!(),
Expr::Subquery(_) => todo!(),
Expr::Unary(_, expr) => {
bind_column_references(expr, referenced_tables, result_columns)?;
Ok(())
}
Expr::Variable(_) => Ok(()),
}
})
}
fn parse_from_clause_table<'a>(