diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 5b64f7b2d..ec657c58c 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -24,7 +24,7 @@ pub fn optimize_plan(plan: &mut Plan) -> Result<()> { */ fn optimize_select_plan(plan: &mut SelectPlan) -> Result<()> { optimize_subqueries(&mut plan.source)?; - eliminate_between(&mut plan.source, &mut plan.where_clause)?; + rewrite_exprs(&mut plan.source, &mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constants(&mut plan.source, &mut plan.where_clause)? { @@ -55,7 +55,7 @@ fn optimize_select_plan(plan: &mut SelectPlan) -> Result<()> { } fn optimize_delete_plan(plan: &mut DeletePlan) -> Result<()> { - eliminate_between(&mut plan.source, &mut plan.where_clause)?; + rewrite_exprs(&mut plan.source, &mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constants(&mut plan.source, &mut plan.where_clause)? { @@ -603,12 +603,14 @@ fn push_scan_direction(operator: &mut SourceOperator, direction: &Direction) { } } -fn eliminate_between( +fn rewrite_exprs( operator: &mut SourceOperator, where_clauses: &mut Option>, ) -> Result<()> { if let Some(predicates) = where_clauses { - *predicates = predicates.drain(..).map(convert_between_expr).collect(); + for expr in predicates.iter_mut() { + rewrite_expr(expr)?; + } } match operator { @@ -618,24 +620,30 @@ fn eliminate_between( predicates, .. } => { - eliminate_between(left, where_clauses)?; - eliminate_between(right, where_clauses)?; + rewrite_exprs(left, where_clauses)?; + rewrite_exprs(right, where_clauses)?; if let Some(predicates) = predicates { - *predicates = predicates.drain(..).map(convert_between_expr).collect(); + for expr in predicates.iter_mut() { + rewrite_expr(expr)?; + } } } SourceOperator::Scan { predicates: Some(preds), .. } => { - *preds = preds.drain(..).map(convert_between_expr).collect(); + for expr in preds.iter_mut() { + rewrite_expr(expr)?; + } } SourceOperator::Search { predicates: Some(preds), .. } => { - *preds = preds.drain(..).map(convert_between_expr).collect(); + for expr in preds.iter_mut() { + rewrite_expr(expr)?; + } } _ => (), } @@ -735,9 +743,17 @@ impl Optimizable for ast::Expr { rhs.check_index_scan(table_index, referenced_tables, available_indexes)?; if rhs_index.is_some() { // swap lhs and rhs + let swapped_operator = match *op { + ast::Operator::Equals => ast::Operator::Equals, + ast::Operator::Greater => ast::Operator::Less, + ast::Operator::GreaterEquals => ast::Operator::LessEquals, + ast::Operator::Less => ast::Operator::Greater, + ast::Operator::LessEquals => ast::Operator::GreaterEquals, + _ => unreachable!(), + }; let lhs_new = rhs.take_ownership(); let rhs_new = lhs.take_ownership(); - *self = Self::Binary(Box::new(lhs_new), *op, Box::new(rhs_new)); + *self = Self::Binary(Box::new(lhs_new), swapped_operator, Box::new(rhs_new)); return Ok(rhs_index); } Ok(None) @@ -747,16 +763,6 @@ impl Optimizable for ast::Expr { } fn check_constant(&self) -> Result> { match self { - Self::Id(id) => { - // true and false are special constants that are effectively aliases for 1 and 0 - if id.0.eq_ignore_ascii_case("true") { - return Ok(Some(ConstantPredicate::AlwaysTrue)); - } - if id.0.eq_ignore_ascii_case("false") { - return Ok(Some(ConstantPredicate::AlwaysFalse)); - } - Ok(None) - } Self::Literal(lit) => match lit { ast::Literal::Null => Ok(Some(ConstantPredicate::AlwaysFalse)), ast::Literal::Numeric(b) => { @@ -967,8 +973,20 @@ pub fn try_extract_index_search_expression( } } -fn convert_between_expr(expr: ast::Expr) -> ast::Expr { +fn rewrite_expr(expr: &mut ast::Expr) -> Result<()> { match expr { + ast::Expr::Id(id) => { + // Convert "true" and "false" to 1 and 0 + if id.0.eq_ignore_ascii_case("true") { + *expr = ast::Expr::Literal(ast::Literal::Numeric(1.to_string())); + return Ok(()); + } + if id.0.eq_ignore_ascii_case("false") { + *expr = ast::Expr::Literal(ast::Literal::Numeric(0.to_string())); + return Ok(()); + } + Ok(()) + } ast::Expr::Between { lhs, not, @@ -976,53 +994,62 @@ fn convert_between_expr(expr: ast::Expr) -> ast::Expr { end, } => { // Convert `y NOT BETWEEN x AND z` to `x > y OR y > z` - let (lower_op, upper_op) = if not { + let (lower_op, upper_op) = if *not { (ast::Operator::Greater, ast::Operator::Greater) } else { // Convert `y BETWEEN x AND z` to `x <= y AND y <= z` (ast::Operator::LessEquals, ast::Operator::LessEquals) }; - let lower_bound = ast::Expr::Binary(start, lower_op, lhs.clone()); - let upper_bound = ast::Expr::Binary(lhs, upper_op, end); + rewrite_expr(start)?; + rewrite_expr(lhs)?; + rewrite_expr(end)?; - if not { - ast::Expr::Binary( + let start = start.take_ownership(); + let lhs = lhs.take_ownership(); + let end = end.take_ownership(); + + let lower_bound = ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone())); + let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end)); + + if *not { + *expr = ast::Expr::Binary( Box::new(lower_bound), ast::Operator::Or, Box::new(upper_bound), - ) + ); } else { - ast::Expr::Binary( + *expr = ast::Expr::Binary( Box::new(lower_bound), ast::Operator::And, Box::new(upper_bound), - ) + ); } + Ok(()) } - ast::Expr::Parenthesized(mut exprs) => { - ast::Expr::Parenthesized(exprs.drain(..).map(convert_between_expr).collect()) + ast::Expr::Parenthesized(ref mut exprs) => { + for subexpr in exprs.iter_mut() { + rewrite_expr(subexpr)?; + } + let exprs = std::mem::take(exprs); + *expr = ast::Expr::Parenthesized(exprs); + Ok(()) } // Process other expressions recursively - ast::Expr::Binary(lhs, op, rhs) => ast::Expr::Binary( - Box::new(convert_between_expr(*lhs)), - op, - Box::new(convert_between_expr(*rhs)), - ), - ast::Expr::FunctionCall { - name, - distinctness, - args, - order_by, - filter_over, - } => ast::Expr::FunctionCall { - name, - distinctness, - args: args.map(|args| args.into_iter().map(convert_between_expr).collect()), - order_by, - filter_over, - }, - _ => expr, + ast::Expr::Binary(lhs, _, rhs) => { + rewrite_expr(lhs)?; + rewrite_expr(rhs)?; + Ok(()) + } + ast::Expr::FunctionCall { args, .. } => { + if let Some(args) = args { + for arg in args.iter_mut() { + rewrite_expr(arg)?; + } + } + Ok(()) + } + _ => Ok(()), } } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index b12391579..abd224d64 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -113,18 +113,18 @@ pub fn bind_column_references( crate::bail_parse_error!("Column {} is ambiguous", id.0); } let col = table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some((tbl_idx, col_idx.unwrap(), col.primary_key)); + match_result = Some((tbl_idx, col_idx.unwrap(), col.is_rowid_alias)); } } if match_result.is_none() { crate::bail_parse_error!("Column {} not found", id.0); } - let (tbl_idx, col_idx, is_primary_key) = match_result.unwrap(); + let (tbl_idx, col_idx, is_rowid_alias) = match_result.unwrap(); *expr = ast::Expr::Column { database: None, // TODO: support different databases table: tbl_idx, column: col_idx, - is_rowid_alias: is_primary_key, + is_rowid_alias, }; Ok(()) } diff --git a/testing/where.test b/testing/where.test index 8a568d0fc..feb4e812a 100755 --- a/testing/where.test +++ b/testing/where.test @@ -338,3 +338,8 @@ do_execsql_test between-price-range-with-names { AND (name = 'sweatshirt' OR name = 'sneakers'); } {5|sweatshirt|74.0 8|sneakers|82.0} + +do_execsql_test where-between-true-and-2 { + select id from users where id between true and 2; +} {1 +2}