Implement custom expression equality checking

This commit is contained in:
PThorpe92
2024-12-15 23:29:56 -05:00
parent 8d18263fd6
commit 25772ee1f3
3 changed files with 271 additions and 6 deletions

View File

@@ -498,6 +498,112 @@ impl Expr {
}
}
/// This function is used to determine whether two expressions are logically
/// equivalent in the context of queries, even if their representations
/// differ. e.g.: `SUM(x)` and `sum(x)`, `x + y` and `y + x`
pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {
use Expr::*;
match (expr1, expr2) {
(
Between {
lhs: lhs1,
not: not1,
start: start1,
end: end1,
},
Between {
lhs: lhs2,
not: not2,
start: start2,
end: end2,
},
) => {
*not1 == *not2
&& exprs_are_equivalent(lhs1, lhs2)
&& exprs_are_equivalent(start1, start2)
&& exprs_are_equivalent(end1, end2)
}
(Binary(lhs1, op1, rhs1), Binary(lhs2, op2, rhs2)) => {
op1 == op2
&& ((exprs_are_equivalent(lhs1, lhs2) && exprs_are_equivalent(rhs1, rhs2))
|| (op1.is_commutative()
&& exprs_are_equivalent(lhs1, rhs2)
&& exprs_are_equivalent(rhs1, lhs2)))
}
(
Case {
base: base1,
when_then_pairs: pairs1,
else_expr: else1,
},
Case {
base: base2,
when_then_pairs: pairs2,
else_expr: else2,
},
) => {
base1 == base2
&& pairs1.len() == pairs2.len()
&& pairs1.iter().zip(pairs2).all(|((w1, t1), (w2, t2))| {
exprs_are_equivalent(w1, w2) && exprs_are_equivalent(t1, t2)
})
&& else1 == else2
}
(
Cast {
expr: expr1,
type_name: type1,
},
Cast {
expr: expr2,
type_name: type2,
},
) => exprs_are_equivalent(expr1, expr2) && type1 == type2,
(Collate(expr1, collation1), Collate(expr2, collation2)) => {
exprs_are_equivalent(expr1, expr2) && collation1.eq_ignore_ascii_case(collation2)
}
(
FunctionCall {
name: name1,
distinctness: distinct1,
args: args1,
order_by: order1,
filter_over: filter1,
},
FunctionCall {
name: name2,
distinctness: distinct2,
args: args2,
order_by: order2,
filter_over: filter2,
},
) => {
name1.0.eq_ignore_ascii_case(&name2.0)
&& distinct1 == distinct2
&& args1 == args2
&& order1 == order2
&& filter1 == filter2
}
(Literal(lit1), Literal(lit2)) => lit1 == lit2,
(Id(id1), Id(id2)) => id1.0.eq_ignore_ascii_case(&id2.0),
(Unary(op1, expr1), Unary(op2, expr2)) => op1 == op2 && exprs_are_equivalent(expr1, expr2),
(Variable(var1), Variable(var2)) => var1 == var2,
(Parenthesized(exprs1), Parenthesized(exprs2)) => {
exprs1.len() == exprs2.len()
&& exprs1
.iter()
.zip(exprs2)
.all(|(e1, e2)| exprs_are_equivalent(e1, e2))
}
(Parenthesized(exprs1), exprs2) => {
exprs1.len() == 1 && exprs_are_equivalent(&exprs1[0], exprs2)
}
// fall back to naive equality check
_ => expr1 == expr2,
}
}
/// SQL literal
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Literal {
@@ -534,6 +640,19 @@ impl Literal {
unreachable!()
}
}
pub fn is_equal(&self, other: &Self) -> bool {
match (self, other) {
(Self::Numeric(n1), Self::Numeric(n2)) => n1 == n2,
(Self::String(s1), Self::String(s2)) => s1 == s2,
(Self::Blob(b1), Self::Blob(b2)) => b1 == b2,
(Self::Keyword(k1), Self::Keyword(k2)) => k1 == k2,
(Self::Null, Self::Null) => true,
(Self::CurrentDate, Self::CurrentDate) => true,
(Self::CurrentTime, Self::CurrentTime) => true,
(Self::CurrentTimestamp, Self::CurrentTimestamp) => true,
_ => false,
}
}
}
/// Textual comparison operator in an expression
@@ -648,6 +767,20 @@ impl From<YYCODETYPE> for Operator {
}
}
impl Operator {
pub fn is_commutative(&self) -> bool {
matches!(
self,
Operator::Add
| Operator::Multiply
| Operator::BitwiseAnd
| Operator::BitwiseOr
| Operator::Equals
| Operator::NotEquals
)
}
}
/// Unary operators
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum UnaryOperator {
@@ -1871,4 +2004,139 @@ mod test {
fn name(s: &'static str) -> Name {
Name(s.to_owned())
}
#[test]
fn test_exprs_are_equivalent() {
use super::{Expr, Id, Literal, Operator::*};
// commutative addition
let expr1 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("1".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("2".to_string()))),
);
let expr2 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("2".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("1".to_string()))),
);
assert!(super::exprs_are_equivalent(&expr1, &expr2));
// non-commutative subtraction
let expr3 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("3".to_string()))),
Subtract,
Box::new(Expr::Literal(Literal::Numeric("2".to_string()))),
);
let expr4 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("2".to_string()))),
Subtract,
Box::new(Expr::Literal(Literal::Numeric("3".to_string()))),
);
assert!(!super::exprs_are_equivalent(&expr3, &expr4));
// case-insensitive function calls
let func1 = Expr::FunctionCall {
name: Id("SUM".to_string()),
distinctness: None,
args: Some(vec![Expr::Id(Id("x".to_string()))]),
order_by: None,
filter_over: None,
};
let func2 = Expr::FunctionCall {
name: Id("sum".to_string()),
distinctness: None,
args: Some(vec![Expr::Id(Id("x".to_string()))]),
order_by: None,
filter_over: None,
};
assert!(super::exprs_are_equivalent(&func1, &func2));
// DISTINCT function argument mismatch
let func3 = Expr::FunctionCall {
name: Id("SUM".to_string()),
distinctness: Some(super::Distinctness::Distinct),
args: Some(vec![Expr::Id(Id("x".to_string()))]),
order_by: None,
filter_over: None,
};
assert!(!super::exprs_are_equivalent(&func1, &func3));
// commutative multiplication
let expr5 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("4".to_string()))),
Multiply,
Box::new(Expr::Literal(Literal::Numeric("5".to_string()))),
);
let expr6 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("5".to_string()))),
Multiply,
Box::new(Expr::Literal(Literal::Numeric("4".to_string()))),
);
assert!(super::exprs_are_equivalent(&expr5, &expr6));
// parenthesized expressions
let expr7 = Expr::Parenthesized(vec![Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("6".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("7".to_string()))),
)]);
let expr8 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("6".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("7".to_string()))),
);
assert!(super::exprs_are_equivalent(&expr7, &expr8));
// LIKE expressions with escape clauses
let expr9 = Expr::Like {
lhs: Box::new(Expr::Id(Id("name".to_string()))),
not: false,
op: super::LikeOperator::Like,
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))),
};
let expr10 = Expr::Like {
lhs: Box::new(Expr::Id(Id("name".to_string()))),
not: false,
op: super::LikeOperator::Like,
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))),
};
assert!(super::exprs_are_equivalent(&expr9, &expr10));
// differing escape clauses in LIKE
let expr11 = Expr::Like {
lhs: Box::new(Expr::Id(Id("name".to_string()))),
not: false,
op: super::LikeOperator::Like,
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
escape: Some(Box::new(Expr::Literal(Literal::String("#".to_string())))),
};
assert!(!super::exprs_are_equivalent(&expr9, &expr11));
// BETWEEN expressions
let expr12 = Expr::Between {
lhs: Box::new(Expr::Id(Id("age".to_string()))),
not: false,
start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))),
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
};
let expr13 = Expr::Between {
lhs: Box::new(Expr::Id(Id("age".to_string()))),
not: false,
start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))),
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
};
assert!(super::exprs_are_equivalent(&expr12, &expr13));
// differing BETWEEN bounds
let expr14 = Expr::Between {
lhs: Box::new(Expr::Id(Id("age".to_string()))),
not: false,
start: Box::new(Expr::Literal(Literal::Numeric("20".to_string()))),
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
};
assert!(!super::exprs_are_equivalent(&expr12, &expr14));
}
}