From 2f647001bc7c6c9cbfe7be92e13d08ef0cc81313 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 16 Dec 2024 12:02:12 -0500 Subject: [PATCH] Add cases to expr equality check, normalize numeric strings --- vendored/sqlite3-parser/src/parser/ast/mod.rs | 249 ++++++++++++++---- 1 file changed, 200 insertions(+), 49 deletions(-) diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index 3613cd6eb..d32e2d8a6 100644 --- a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -523,7 +523,6 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { && exprs_are_equivalent(start1, start2) && exprs_are_equivalent(end1, end2) } - (Binary(lhs1, op1, rhs1), Binary(lhs2, op2, rhs2)) => { op1 == op2 && ((exprs_are_equivalent(lhs1, lhs2) && exprs_are_equivalent(rhs1, rhs2)) @@ -585,7 +584,7 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { && order1 == order2 && filter1 == filter2 } - (Literal(lit1), Literal(lit2)) => lit1 == lit2, + (Literal(lit1), Literal(lit2)) => lit1.is_equivalent(lit2), (Id(id1), Id(id2)) => id1.0.eq_ignore_ascii_case(&id2.0), (Unary(op1, expr1), Unary(op2, expr2)) => op1 == op2 && exprs_are_equivalent(expr1, expr2), (Variable(var1), Variable(var2)) => var1 == var2, @@ -596,9 +595,35 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { .zip(exprs2) .all(|(e1, e2)| exprs_are_equivalent(e1, e2)) } - (Parenthesized(exprs1), exprs2) => { + (Parenthesized(exprs1), exprs2) | (exprs2, Parenthesized(exprs1)) => { exprs1.len() == 1 && exprs_are_equivalent(&exprs1[0], exprs2) } + ( + InList { + lhs: lhs1, + not: not1, + rhs: rhs1, + }, + InList { + lhs: lhs2, + not: not2, + rhs: rhs2, + }, + ) => { + *not1 == *not2 + && exprs_are_equivalent(lhs1, lhs2) + && rhs1 + .as_ref() + .zip(rhs2.as_ref()) + .map(|(list1, list2)| { + list1.len() == list2.len() + && list1 + .iter() + .zip(list2) + .all(|(e1, e2)| 
exprs_are_equivalent(e1, e2)) + }) + .unwrap_or(false) + } // fall back to naive equality check _ => expr1 == expr2, } @@ -627,6 +652,20 @@ pub enum Literal { CurrentTimestamp, } +/// normalization for comparison of numeric literals +fn normalize_numeric_str(num_str: &str) -> Option<String> { + if let Ok(value) = num_str.parse::<f64>() { + let canonical = if value.fract() == 0.0 { + format!("{}", value as i64) + } else { + format!("{}", value) + }; + Some(canonical) + } else { + None + } +} + impl Literal { /// Constructor pub fn from_ctime_kw(token: Token) -> Self { @@ -640,12 +679,20 @@ impl Literal { unreachable!() } } - pub fn is_equal(&self, other: &Self) -> bool { + + /// checks if two literal values are equivalent + fn is_equivalent(&self, other: &Self) -> bool { match (self, other) { - (Self::Numeric(n1), Self::Numeric(n2)) => n1 == n2, - (Self::String(s1), Self::String(s2)) => s1 == s2, + (Self::Numeric(n1), Self::Numeric(n2)) => { + match (normalize_numeric_str(n1), normalize_numeric_str(n2)) { + (Some(canonical), Some(canonical2)) => canonical == canonical2, + _ => false, + } + } + // TODO: check for quoted == unquoted strings? 
+ (Self::String(s1), Self::String(s2)) => s1.eq_ignore_ascii_case(s2), (Self::Blob(b1), Self::Blob(b2)) => b1 == b2, - (Self::Keyword(k1), Self::Keyword(k2)) => k1 == k2, + (Self::Keyword(k1), Self::Keyword(k2)) => k1.eq_ignore_ascii_case(k2), (Self::Null, Self::Null) => true, (Self::CurrentDate, Self::CurrentDate) => true, (Self::CurrentTime, Self::CurrentTime) => true, @@ -768,6 +815,7 @@ impl From for Operator { } impl Operator { + /// returns whether order of operations can be ignored pub fn is_commutative(&self) -> bool { matches!( self, @@ -2006,36 +2054,72 @@ mod test { } #[test] - fn test_exprs_are_equivalent() { - use super::{Expr, Id, Literal, Operator::*}; - - // commutative addition + fn test_basic_addition_exprs_are_equivalent() { + use super::{Expr, Literal, Operator::*}; let expr1 = Expr::Binary( - Box::new(Expr::Literal(Literal::Numeric("1".to_string()))), + Box::new(Expr::Literal(Literal::Numeric("826".to_string()))), Add, - Box::new(Expr::Literal(Literal::Numeric("2".to_string()))), + Box::new(Expr::Literal(Literal::Numeric("389".to_string()))), ); let expr2 = Expr::Binary( - Box::new(Expr::Literal(Literal::Numeric("2".to_string()))), + Box::new(Expr::Literal(Literal::Numeric("389".to_string()))), Add, - Box::new(Expr::Literal(Literal::Numeric("1".to_string()))), + Box::new(Expr::Literal(Literal::Numeric("826".to_string()))), ); assert!(super::exprs_are_equivalent(&expr1, &expr2)); + } - // non-commutative subtraction + #[test] + fn test_addition_expressions_equivalent_normalized() { + use super::{Expr, Literal, Operator::*}; + let expr1 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("123.0".to_string()))), + Add, + Box::new(Expr::Literal(Literal::Numeric("243".to_string()))), + ); + let expr2 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("243.0".to_string()))), + Add, + Box::new(Expr::Literal(Literal::Numeric("123".to_string()))), + ); + assert!(super::exprs_are_equivalent(&expr1, &expr2)); + } + + #[test] + fn 
test_subtraction_expressions_not_equivalent() { + use super::{Expr, Literal, Operator::*}; let expr3 = Expr::Binary( - Box::new(Expr::Literal(Literal::Numeric("3".to_string()))), + Box::new(Expr::Literal(Literal::Numeric("364".to_string()))), Subtract, - Box::new(Expr::Literal(Literal::Numeric("2".to_string()))), + Box::new(Expr::Literal(Literal::Numeric("22.0".to_string()))), ); let expr4 = Expr::Binary( - Box::new(Expr::Literal(Literal::Numeric("2".to_string()))), + Box::new(Expr::Literal(Literal::Numeric("22.0".to_string()))), Subtract, - Box::new(Expr::Literal(Literal::Numeric("3".to_string()))), + Box::new(Expr::Literal(Literal::Numeric("364".to_string()))), ); assert!(!super::exprs_are_equivalent(&expr3, &expr4)); + } - // case-insensitive function calls + #[test] + fn test_subtraction_expressions_normalized() { + use super::{Expr, Literal, Operator::*}; + let expr3 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("66.0".to_string()))), + Subtract, + Box::new(Expr::Literal(Literal::Numeric("22".to_string()))), + ); + let expr4 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("66".to_string()))), + Subtract, + Box::new(Expr::Literal(Literal::Numeric("22.0".to_string()))), + ); + assert!(super::exprs_are_equivalent(&expr3, &expr4)); + } + + #[test] + fn test_expressions_equivalent_case_insensitive_functioncalls() { + use super::{Expr, Id}; let func1 = Expr::FunctionCall { name: Id("SUM".to_string()), distinctness: None, @@ -2052,7 +2136,6 @@ mod test { }; assert!(super::exprs_are_equivalent(&func1, &func2)); - // DISTINCT function argument mismatch let func3 = Expr::FunctionCall { name: Id("SUM".to_string()), distinctness: Some(super::Distinctness::Distinct), @@ -2061,21 +2144,62 @@ mod test { filter_over: None, }; assert!(!super::exprs_are_equivalent(&func1, &func3)); + } - // commutative multiplication - let expr5 = Expr::Binary( - Box::new(Expr::Literal(Literal::Numeric("4".to_string()))), - Multiply, - 
Box::new(Expr::Literal(Literal::Numeric("5".to_string()))), - ); - let expr6 = Expr::Binary( - Box::new(Expr::Literal(Literal::Numeric("5".to_string()))), - Multiply, - Box::new(Expr::Literal(Literal::Numeric("4".to_string()))), - ); - assert!(super::exprs_are_equivalent(&expr5, &expr6)); + #[test] + fn test_expressions_equivalent_identical_fn_with_distinct() { + use super::{Expr, Id}; + let sum = Expr::FunctionCall { + name: Id("SUM".to_string()), + distinctness: None, + args: Some(vec![Expr::Id(Id("x".to_string()))]), + order_by: None, + filter_over: None, + }; + let sum_distinct = Expr::FunctionCall { + name: Id("SUM".to_string()), + distinctness: Some(super::Distinctness::Distinct), + args: Some(vec![Expr::Id(Id("x".to_string()))]), + order_by: None, + filter_over: None, + }; + assert!(!super::exprs_are_equivalent(&sum, &sum_distinct)); + } - // parenthesized expressions + #[test] + fn test_expressions_equivalent_multiplicaiton() { + use super::{Expr, Literal, Operator::*}; + let expr1 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("42.0".to_string()))), + Multiply, + Box::new(Expr::Literal(Literal::Numeric("38".to_string()))), + ); + let expr2 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("38.0".to_string()))), + Multiply, + Box::new(Expr::Literal(Literal::Numeric("42".to_string()))), + ); + assert!(super::exprs_are_equivalent(&expr1, &expr2)); + } + + #[test] + fn test_expressions_both_parenthesized_equivalent() { + use super::{Expr, Literal, Operator::*}; + let expr1 = Expr::Parenthesized(vec![Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("683".to_string()))), + Add, + Box::new(Expr::Literal(Literal::Numeric("799.0".to_string()))), + )]); + let expr2 = Expr::Binary( + Box::new(Expr::Literal(Literal::Numeric("799".to_string()))), + Add, + Box::new(Expr::Literal(Literal::Numeric("683".to_string()))), + ); + assert!(super::exprs_are_equivalent(&expr1, &expr2)); + } + #[test] + fn test_expressions_parenthesized_equivalent() { + 
use super::{Expr, Literal, Operator::*}; let expr7 = Expr::Parenthesized(vec![Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("6".to_string()))), Add, @@ -2087,56 +2211,83 @@ mod test { Box::new(Expr::Literal(Literal::Numeric("7".to_string()))), ); assert!(super::exprs_are_equivalent(&expr7, &expr8)); + } - // LIKE expressions with escape clauses - let expr9 = Expr::Like { + #[test] + fn test_like_expressions_equivalent() { + use super::{Expr, Id, Literal}; + let expr1 = Expr::Like { lhs: Box::new(Expr::Id(Id("name".to_string()))), not: false, op: super::LikeOperator::Like, rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))), escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))), }; - let expr10 = Expr::Like { + let expr2 = Expr::Like { lhs: Box::new(Expr::Id(Id("name".to_string()))), not: false, op: super::LikeOperator::Like, rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))), escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))), }; - assert!(super::exprs_are_equivalent(&expr9, &expr10)); + assert!(super::exprs_are_equivalent(&expr1, &expr2)); + } - // differing escape clauses in LIKE - let expr11 = Expr::Like { + #[test] + fn test_expressions_equivalent_like_escaped() { + use super::{Expr, Id, Literal}; + let expr1 = Expr::Like { + lhs: Box::new(Expr::Id(Id("name".to_string()))), + not: false, + op: super::LikeOperator::Like, + rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))), + escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))), + }; + let expr2 = Expr::Like { lhs: Box::new(Expr::Id(Id("name".to_string()))), not: false, op: super::LikeOperator::Like, rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))), escape: Some(Box::new(Expr::Literal(Literal::String("#".to_string())))), }; - assert!(!super::exprs_are_equivalent(&expr9, &expr11)); - - // BETWEEN expressions - let expr12 = Expr::Between { + 
assert!(!super::exprs_are_equivalent(&expr1, &expr2)); + } + #[test] + fn test_expressions_equivalent_between() { + use super::{Expr, Id, Literal}; + let expr1 = Expr::Between { lhs: Box::new(Expr::Id(Id("age".to_string()))), not: false, start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))), end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))), }; - let expr13 = Expr::Between { + let expr2 = Expr::Between { lhs: Box::new(Expr::Id(Id("age".to_string()))), not: false, start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))), end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))), }; - assert!(super::exprs_are_equivalent(&expr12, &expr13)); + assert!(super::exprs_are_equivalent(&expr1, &expr2)); // differing BETWEEN bounds - let expr14 = Expr::Between { + let expr3 = Expr::Between { lhs: Box::new(Expr::Id(Id("age".to_string()))), not: false, start: Box::new(Expr::Literal(Literal::Numeric("20".to_string()))), end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))), }; - assert!(!super::exprs_are_equivalent(&expr12, &expr14)); + assert!(!super::exprs_are_equivalent(&expr1, &expr3)); + } + + #[test] + fn test_normalize_numeric_string() { + use super::normalize_numeric_str; + assert_eq!(normalize_numeric_str("001"), Some("1".to_string())); + assert_eq!(normalize_numeric_str("1.00"), Some("1".to_string())); + assert_eq!(normalize_numeric_str("0.010"), Some("0.01".to_string())); + assert_eq!(normalize_numeric_str("1e3"), Some("1000".to_string())); + assert_eq!(normalize_numeric_str("1.23e2"), Some("123".to_string())); + assert_eq!(normalize_numeric_str("invalid"), None); + assert_eq!(normalize_numeric_str(""), None); } }