From 25772ee1f3cb5fabc9f5cd93e08857c8c927bdf7 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 15 Dec 2024 23:29:56 -0500 Subject: [PATCH] Implement custom expression equality checking --- core/translate/emitter.rs | 7 +- core/translate/mod.rs | 2 +- vendored/sqlite3-parser/src/parser/ast/mod.rs | 268 ++++++++++++++++++ 3 files changed, 271 insertions(+), 6 deletions(-) diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index d276adb3b..f7658ccfa 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::rc::{Rc, Weak}; -use sqlite3_parser::ast; +use sqlite3_parser::ast::{self, exprs_are_equivalent}; use crate::schema::{Column, PseudoTable, Table}; use crate::storage::sqlite3_ondisk::DatabaseHeader; @@ -1766,13 +1766,10 @@ fn order_by_deduplicate_result_columns( ) -> Option> { let mut result_column_remapping: Option> = None; for (i, rc) in result_columns.iter().enumerate() { - // TODO: implement a custom equality check for expressions - // there are lots of examples where this breaks, even simple ones like - // sum(x) != SUM(x) let found = order_by .iter() .enumerate() - .find(|(_, (expr, _))| expr == &rc.expr); + .find(|(_, (expr, _))| exprs_are_equivalent(expr, &rc.expr)); if let Some((j, _)) = found { if let Some(ref mut v) = result_column_remapping { v.push((i, j)); diff --git a/core/translate/mod.rs b/core/translate/mod.rs index cb2463239..e7ce74d4b 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -392,7 +392,7 @@ fn update_pragma(name: &str, value: i64, header: Rc>, pa struct TableFormatter<'a> { body: &'a ast::CreateTableBody, } -impl<'a> Display for TableFormatter<'a> { +impl<'a> Display for TableFormatter<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.body.to_fmt(f) } diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index da1798cff..3613cd6eb 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -498,6 +498,112 @@ impl Expr { } } +/// This function is used to determine whether two expressions are logically +/// equivalent in the context of queries, even if their representations +/// differ. e.g.: `SUM(x)` and `sum(x)`, `x + y` and `y + x` +pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { + use Expr::*; + match (expr1, expr2) { + ( + Between { + lhs: lhs1, + not: not1, + start: start1, + end: end1, + }, + Between { + lhs: lhs2, + not: not2, + start: start2, + end: end2, + }, + ) => { + *not1 == *not2 + && exprs_are_equivalent(lhs1, lhs2) + && exprs_are_equivalent(start1, start2) + && exprs_are_equivalent(end1, end2) + } + + (Binary(lhs1, op1, rhs1), Binary(lhs2, op2, rhs2)) => { + op1 == op2 + && ((exprs_are_equivalent(lhs1, lhs2) && exprs_are_equivalent(rhs1, rhs2)) + || (op1.is_commutative() + && exprs_are_equivalent(lhs1, rhs2) + && exprs_are_equivalent(rhs1, lhs2))) + } + ( + Case { + base: base1, + when_then_pairs: pairs1, + else_expr: else1, + }, + Case { + base: base2, + when_then_pairs: pairs2, + else_expr: else2, + }, + ) => { + base1 == base2 + && pairs1.len() == pairs2.len() + && pairs1.iter().zip(pairs2).all(|((w1, t1), (w2, t2))| { + exprs_are_equivalent(w1, w2) && exprs_are_equivalent(t1, t2) + }) + && else1 == else2 + } + ( + Cast { + expr: expr1, + type_name: type1, + }, + Cast { + expr: expr2, + type_name: type2, + }, + ) => exprs_are_equivalent(expr1, expr2) && type1 == type2, + (Collate(expr1, collation1), Collate(expr2, collation2)) => { + exprs_are_equivalent(expr1, expr2) && collation1.eq_ignore_ascii_case(collation2) + } + ( + FunctionCall { + name: name1, + distinctness: distinct1, + args: args1, + order_by: order1, + filter_over: filter1, + }, + FunctionCall { + name: name2, + distinctness: distinct2, + args: args2, + order_by: order2, + filter_over: filter2, + }, + ) => { + name1.0.eq_ignore_ascii_case(&name2.0) + && distinct1 == distinct2 + && args1 == args2 + && order1 == order2 + && filter1 == filter2 + } + (Literal(lit1), Literal(lit2)) => lit1 == lit2, + (Id(id1), Id(id2)) => id1.0.eq_ignore_ascii_case(&id2.0), + (Unary(op1, expr1), Unary(op2, expr2)) => op1 == op2 && exprs_are_equivalent(expr1, expr2), + (Variable(var1), Variable(var2)) => var1 == var2, + (Parenthesized(exprs1), Parenthesized(exprs2)) => { + exprs1.len() == exprs2.len() + && exprs1 + .iter() + .zip(exprs2) + .all(|(e1, e2)| exprs_are_equivalent(e1, e2)) + } + (Parenthesized(exprs1), exprs2) => { + exprs1.len() == 1 && exprs_are_equivalent(&exprs1[0], exprs2) + } + // fall back to naive equality check + _ => expr1 == expr2, + } +} + /// SQL literal #[derive(Clone, Debug, PartialEq, Eq)] pub enum Literal { @@ -534,6 +640,19 @@ impl Literal { unreachable!() } } + pub fn is_equal(&self, other: &Self) -> bool { + match (self, other) { + (Self::Numeric(n1), Self::Numeric(n2)) => n1 == n2, + (Self::String(s1), Self::String(s2)) => s1 == s2, + (Self::Blob(b1), Self::Blob(b2)) => b1 == b2, + (Self::Keyword(k1), Self::Keyword(k2)) => k1 == k2, + (Self::Null, Self::Null) => true, + (Self::CurrentDate, Self::CurrentDate) => true, + (Self::CurrentTime, Self::CurrentTime) => true, + (Self::CurrentTimestamp, Self::CurrentTimestamp) => true, + _ => false, + } + } } /// Textual comparison operator in an expression @@ -648,6 +767,20 @@ impl From for Operator { } } +impl Operator { + pub fn is_commutative(&self) -> bool { + matches!( + self, + Operator::Add + | Operator::Multiply + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::Equals + | Operator::NotEquals + ) + } +} + /// Unary operators #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum UnaryOperator { @@ -1871,4 +2004,139 @@ mod test { fn name(s: &'static str) -> Name { Name(s.to_owned()) } + + #[test] + fn test_exprs_are_equivalent() { + use super::{Expr, Id, Literal, Operator::*}; + + // commutative addition + let expr1 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("1".to_string()))), + Add, + Box::new(Expr::Literal(Literal::Numeric("2".to_string()))), + ); + let expr2 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("2".to_string()))), + Add, + Box::new(Expr::Literal(Literal::Numeric("1".to_string()))), + ); + assert!(super::exprs_are_equivalent(&expr1, &expr2)); + + // non-commutative subtraction + let expr3 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("3".to_string()))), + Subtract, + Box::new(Expr::Literal(Literal::Numeric("2".to_string()))), + ); + let expr4 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("2".to_string()))), + Subtract, + Box::new(Expr::Literal(Literal::Numeric("3".to_string()))), + ); + assert!(!super::exprs_are_equivalent(&expr3, &expr4)); + + // case-insensitive function calls + let func1 = Expr::FunctionCall { + name: Id("SUM".to_string()), + distinctness: None, + args: Some(vec![Expr::Id(Id("x".to_string()))]), + order_by: None, + filter_over: None, + }; + let func2 = Expr::FunctionCall { + name: Id("sum".to_string()), + distinctness: None, + args: Some(vec![Expr::Id(Id("x".to_string()))]), + order_by: None, + filter_over: None, + }; + assert!(super::exprs_are_equivalent(&func1, &func2)); + + // DISTINCT function argument mismatch + let func3 = Expr::FunctionCall { + name: Id("SUM".to_string()), + distinctness: Some(super::Distinctness::Distinct), + args: Some(vec![Expr::Id(Id("x".to_string()))]), + order_by: None, + filter_over: None, + }; + assert!(!super::exprs_are_equivalent(&func1, &func3)); + + // commutative multiplication + let expr5 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("4".to_string()))), + Multiply, + Box::new(Expr::Literal(Literal::Numeric("5".to_string()))), + ); + let expr6 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("5".to_string()))), + Multiply, + Box::new(Expr::Literal(Literal::Numeric("4".to_string()))), + ); + assert!(super::exprs_are_equivalent(&expr5, &expr6)); + + // parenthesized expressions + let expr7 = Expr::Parenthesized(vec![Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("6".to_string()))), + Add, + Box::new(Expr::Literal(Literal::Numeric("7".to_string()))), + )]); + let expr8 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("6".to_string()))), + Add, + Box::new(Expr::Literal(Literal::Numeric("7".to_string()))), + ); + assert!(super::exprs_are_equivalent(&expr7, &expr8)); + + // LIKE expressions with escape clauses + let expr9 = Expr::Like { + lhs: Box::new(Expr::Id(Id("name".to_string()))), + not: false, + op: super::LikeOperator::Like, + rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))), + escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))), + }; + let expr10 = Expr::Like { + lhs: Box::new(Expr::Id(Id("name".to_string()))), + not: false, + op: super::LikeOperator::Like, + rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))), + escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))), + }; + assert!(super::exprs_are_equivalent(&expr9, &expr10)); + + // differing escape clauses in LIKE + let expr11 = Expr::Like { + lhs: Box::new(Expr::Id(Id("name".to_string()))), + not: false, + op: super::LikeOperator::Like, + rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))), + escape: Some(Box::new(Expr::Literal(Literal::String("#".to_string())))), + }; + assert!(!super::exprs_are_equivalent(&expr9, &expr11)); + + // BETWEEN expressions + let expr12 = Expr::Between { + lhs: Box::new(Expr::Id(Id("age".to_string()))), + not: false, + start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))), + end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))), + }; + let expr13 = Expr::Between { + lhs: Box::new(Expr::Id(Id("age".to_string()))), + not: false, + start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))), + end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))), + }; + assert!(super::exprs_are_equivalent(&expr12, &expr13)); + + // differing BETWEEN bounds + let expr14 = Expr::Between { + lhs: Box::new(Expr::Id(Id("age".to_string()))), + not: false, + start: Box::new(Expr::Literal(Literal::Numeric("20".to_string()))), + end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))), + }; + assert!(!super::exprs_are_equivalent(&expr12, &expr14)); + } }