Fix not using index when expr is paren-wrapped: e.g. SELECT * FROM t WHERE (x > 5)

This commit is contained in:
Jussi Saurio
2025-04-11 10:53:35 +03:00
parent c6bea835f9
commit 0d97e2a311

View File

@@ -797,6 +797,47 @@ enum BinaryExprSide {
Rhs,
}
/// Recursively unwrap parentheses from an expression
/// e.g. (((t.x > 5))) -> t.x > 5
fn unwrap_parens<T>(expr: T) -> Result<T>
where
T: UnwrapParens,
{
expr.unwrap_parens()
}
trait UnwrapParens {
fn unwrap_parens(self) -> Result<Self>
where
Self: Sized;
}
impl UnwrapParens for &ast::Expr {
fn unwrap_parens(self) -> Result<Self> {
match self {
ast::Expr::Column { .. } => Ok(self),
ast::Expr::Parenthesized(exprs) => match exprs.len() {
1 => unwrap_parens(exprs.first().unwrap()),
_ => crate::bail_parse_error!("expected single expression in parentheses"),
},
_ => Ok(self),
}
}
}
impl UnwrapParens for ast::Expr {
fn unwrap_parens(self) -> Result<Self> {
match self {
ast::Expr::Column { .. } => Ok(self),
ast::Expr::Parenthesized(mut exprs) => match exprs.len() {
1 => unwrap_parens(exprs.pop().unwrap()),
_ => crate::bail_parse_error!("expected single expression in parentheses"),
},
_ => Ok(self),
}
}
}
/// Get the position of a column in an index
/// For example, if there is an index on table T(x,y) then y's position in the index is 1.
fn get_column_position_in_index(
@@ -804,20 +845,20 @@ fn get_column_position_in_index(
table_index: usize,
table_reference: &TableReference,
index: &Arc<Index>,
) -> Option<usize> {
let ast::Expr::Column { table, column, .. } = expr else {
return None;
) -> Result<Option<usize>> {
let ast::Expr::Column { table, column, .. } = unwrap_parens(expr)? else {
return Ok(None);
};
if *table != table_index {
return None;
return Ok(None);
}
let Some(column) = table_reference.table.get_column_at(*column) else {
return None;
return Ok(None);
};
index
Ok(index
.columns
.iter()
.position(|col| Some(&col.name) == column.name.as_ref())
.position(|col| Some(&col.name) == column.name.as_ref()))
}
/// Find all [IndexConstraint]s for a given WHERE clause
@@ -839,7 +880,7 @@ fn find_index_constraints(
continue;
}
// Skip terms that are not binary comparisons
let ast::Expr::Binary(lhs, operator, rhs) = &term.expr else {
let ast::Expr::Binary(lhs, operator, rhs) = unwrap_parens(&term.expr)? else {
continue;
};
// Only consider index scans for binary ops that are comparisons
@@ -868,7 +909,7 @@ fn find_index_constraints(
// Check if lhs is a column that is in the i'th position of the index
if Some(position_in_index)
== get_column_position_in_index(lhs, table_index, table_reference, index)
== get_column_position_in_index(lhs, table_index, table_reference, index)?
{
out_constraints.push(IndexConstraint {
operator: *operator,
@@ -879,7 +920,7 @@ fn find_index_constraints(
}
// Check if rhs is a column that is in the i'th position of the index
if Some(position_in_index)
== get_column_position_in_index(rhs, table_index, table_reference, index)
== get_column_position_in_index(rhs, table_index, table_reference, index)?
{
out_constraints.push(IndexConstraint {
operator: opposite_cmp_op(*operator), // swap the operator since e.g. if condition is 5 >= x, we want to use x <= 5
@@ -931,7 +972,8 @@ pub fn build_seek_def_from_index_constraints(
// Extract the other expression from the binary WhereTerm (i.e. the one being compared to the index column)
let (idx, side) = constraint.position_in_where_clause;
let where_term = &mut where_clause[idx];
let ast::Expr::Binary(lhs, _, rhs) = where_term.expr.take_ownership() else {
let ast::Expr::Binary(lhs, _, rhs) = unwrap_parens(where_term.expr.take_ownership())?
else {
crate::bail_parse_error!("expected binary expression");
};
let cmp_expr = if side == BinaryExprSide::Lhs {