From 1bcdf99eab506b2a44aecc3eb6c225629735b270 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Fri, 10 Jan 2025 10:04:07 +0200 Subject: [PATCH] core/optimizer: do expression rewriting on all expressions --- core/translate/optimizer.rs | 91 +++++++++++++++++++++++++------------ core/translate/planner.rs | 8 ++++ testing/select.test | 17 +++++++ 3 files changed, 86 insertions(+), 30 deletions(-) diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index ec657c58c..d6caba85e 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -24,7 +24,7 @@ pub fn optimize_plan(plan: &mut Plan) -> Result<()> { */ fn optimize_select_plan(plan: &mut SelectPlan) -> Result<()> { optimize_subqueries(&mut plan.source)?; - rewrite_exprs(&mut plan.source, &mut plan.where_clause)?; + rewrite_exprs_select(plan)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constants(&mut plan.source, &mut plan.where_clause)? { @@ -55,7 +55,7 @@ fn optimize_select_plan(plan: &mut SelectPlan) -> Result<()> { } fn optimize_delete_plan(plan: &mut DeletePlan) -> Result<()> { - rewrite_exprs(&mut plan.source, &mut plan.where_clause)?; + rewrite_exprs_delete(plan)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constants(&mut plan.source, &mut plan.where_clause)? { @@ -603,16 +603,45 @@ fn push_scan_direction(operator: &mut SourceOperator, direction: &Direction) { } } -fn rewrite_exprs( - operator: &mut SourceOperator, - where_clauses: &mut Option>, -) -> Result<()> { - if let Some(predicates) = where_clauses { - for expr in predicates.iter_mut() { +fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> { + rewrite_source_operator_exprs(&mut plan.source)?; + for rc in plan.result_columns.iter_mut() { + rewrite_expr(&mut rc.expr)?; + } + for agg in plan.aggregates.iter_mut() { + rewrite_expr(&mut agg.original_expr)?; + } + if let Some(predicates) = &mut plan.where_clause { + for expr in predicates { + rewrite_expr(expr)?; + } + } + if let Some(group_by) = &mut plan.group_by { + for expr in group_by.exprs.iter_mut() { + rewrite_expr(expr)?; + } + } + if let Some(order_by) = &mut plan.order_by { + for (expr, _) in order_by.iter_mut() { rewrite_expr(expr)?; } } + Ok(()) +} + +fn rewrite_exprs_delete(plan: &mut DeletePlan) -> Result<()> { + rewrite_source_operator_exprs(&mut plan.source)?; + if let Some(predicates) = &mut plan.where_clause { + for expr in predicates { + rewrite_expr(expr)?; + } + } + + Ok(()) +} + +fn rewrite_source_operator_exprs(operator: &mut SourceOperator) -> Result<()> { match operator { SourceOperator::Join { left, @@ -620,35 +649,37 @@ fn rewrite_exprs( predicates, .. } => { - rewrite_exprs(left, where_clauses)?; - rewrite_exprs(right, where_clauses)?; + rewrite_source_operator_exprs(left)?; + rewrite_source_operator_exprs(right)?; if let Some(predicates) = predicates { for expr in predicates.iter_mut() { rewrite_expr(expr)?; } } - } - SourceOperator::Scan { - predicates: Some(preds), - .. - } => { - for expr in preds.iter_mut() { - rewrite_expr(expr)?; - } - } - SourceOperator::Search { - predicates: Some(preds), - .. - } => { - for expr in preds.iter_mut() { - rewrite_expr(expr)?; - } - } - _ => (), - } - Ok(()) + Ok(()) + } + SourceOperator::Scan { predicates, .. } | SourceOperator::Search { predicates, .. } => { + if let Some(predicates) = predicates { + for expr in predicates.iter_mut() { + rewrite_expr(expr)?; + } + } + + Ok(()) + } + SourceOperator::Subquery { predicates, .. } => { + if let Some(predicates) = predicates { + for expr in predicates.iter_mut() { + rewrite_expr(expr)?; + } + } + + Ok(()) + } + SourceOperator::Nothing { .. } => Ok(()), + } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/core/translate/planner.rs b/core/translate/planner.rs index abd224d64..64a0ffb04 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -544,6 +544,14 @@ fn parse_join( pub fn parse_limit(limit: Limit) -> Option { if let Expr::Literal(ast::Literal::Numeric(n)) = limit.expr { n.parse().ok() + } else if let Expr::Id(id) = limit.expr { + if id.0.eq_ignore_ascii_case("true") { + Some(1) + } else if id.0.eq_ignore_ascii_case("false") { + Some(0) + } else { + None + } } else { None } diff --git a/testing/select.test b/testing/select.test index c6d403a6a..49f8021bc 100755 --- a/testing/select.test +++ b/testing/select.test @@ -11,6 +11,14 @@ do_execsql_test select-const-2 { SELECT 2 } {2} +do_execsql_test select-true { + SELECT true +} {1} + +do_execsql_test select-false { + SELECT false +} {0} + do_execsql_test select-text-escape-1 { SELECT '''a' } {'a} @@ -31,6 +39,15 @@ do_execsql_test select-limit-0 { SELECT id FROM users LIMIT 0; } {} +# ORDER BY id here because sqlite uses age_idx here and we (yet) don't so force it to evaluate in ID order +do_execsql_test select-limit-true { + SELECT id FROM users ORDER BY id LIMIT true; +} {1} + +do_execsql_test select-limit-false { + SELECT id FROM users ORDER BY id LIMIT false; +} {} + do_execsql_test realify { select price from products limit 1; } {79.0}