mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-06 09:44:21 +01:00
Implement custom expression equality checking
This commit is contained in:
@@ -498,6 +498,112 @@ impl Expr {
|
||||
}
|
||||
}
|
||||
|
||||
/// This function is used to determine whether two expressions are logically
|
||||
/// equivalent in the context of queries, even if their representations
|
||||
/// differ. e.g.: `SUM(x)` and `sum(x)`, `x + y` and `y + x`
|
||||
pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {
|
||||
use Expr::*;
|
||||
match (expr1, expr2) {
|
||||
(
|
||||
Between {
|
||||
lhs: lhs1,
|
||||
not: not1,
|
||||
start: start1,
|
||||
end: end1,
|
||||
},
|
||||
Between {
|
||||
lhs: lhs2,
|
||||
not: not2,
|
||||
start: start2,
|
||||
end: end2,
|
||||
},
|
||||
) => {
|
||||
*not1 == *not2
|
||||
&& exprs_are_equivalent(lhs1, lhs2)
|
||||
&& exprs_are_equivalent(start1, start2)
|
||||
&& exprs_are_equivalent(end1, end2)
|
||||
}
|
||||
|
||||
(Binary(lhs1, op1, rhs1), Binary(lhs2, op2, rhs2)) => {
|
||||
op1 == op2
|
||||
&& ((exprs_are_equivalent(lhs1, lhs2) && exprs_are_equivalent(rhs1, rhs2))
|
||||
|| (op1.is_commutative()
|
||||
&& exprs_are_equivalent(lhs1, rhs2)
|
||||
&& exprs_are_equivalent(rhs1, lhs2)))
|
||||
}
|
||||
(
|
||||
Case {
|
||||
base: base1,
|
||||
when_then_pairs: pairs1,
|
||||
else_expr: else1,
|
||||
},
|
||||
Case {
|
||||
base: base2,
|
||||
when_then_pairs: pairs2,
|
||||
else_expr: else2,
|
||||
},
|
||||
) => {
|
||||
base1 == base2
|
||||
&& pairs1.len() == pairs2.len()
|
||||
&& pairs1.iter().zip(pairs2).all(|((w1, t1), (w2, t2))| {
|
||||
exprs_are_equivalent(w1, w2) && exprs_are_equivalent(t1, t2)
|
||||
})
|
||||
&& else1 == else2
|
||||
}
|
||||
(
|
||||
Cast {
|
||||
expr: expr1,
|
||||
type_name: type1,
|
||||
},
|
||||
Cast {
|
||||
expr: expr2,
|
||||
type_name: type2,
|
||||
},
|
||||
) => exprs_are_equivalent(expr1, expr2) && type1 == type2,
|
||||
(Collate(expr1, collation1), Collate(expr2, collation2)) => {
|
||||
exprs_are_equivalent(expr1, expr2) && collation1.eq_ignore_ascii_case(collation2)
|
||||
}
|
||||
(
|
||||
FunctionCall {
|
||||
name: name1,
|
||||
distinctness: distinct1,
|
||||
args: args1,
|
||||
order_by: order1,
|
||||
filter_over: filter1,
|
||||
},
|
||||
FunctionCall {
|
||||
name: name2,
|
||||
distinctness: distinct2,
|
||||
args: args2,
|
||||
order_by: order2,
|
||||
filter_over: filter2,
|
||||
},
|
||||
) => {
|
||||
name1.0.eq_ignore_ascii_case(&name2.0)
|
||||
&& distinct1 == distinct2
|
||||
&& args1 == args2
|
||||
&& order1 == order2
|
||||
&& filter1 == filter2
|
||||
}
|
||||
(Literal(lit1), Literal(lit2)) => lit1 == lit2,
|
||||
(Id(id1), Id(id2)) => id1.0.eq_ignore_ascii_case(&id2.0),
|
||||
(Unary(op1, expr1), Unary(op2, expr2)) => op1 == op2 && exprs_are_equivalent(expr1, expr2),
|
||||
(Variable(var1), Variable(var2)) => var1 == var2,
|
||||
(Parenthesized(exprs1), Parenthesized(exprs2)) => {
|
||||
exprs1.len() == exprs2.len()
|
||||
&& exprs1
|
||||
.iter()
|
||||
.zip(exprs2)
|
||||
.all(|(e1, e2)| exprs_are_equivalent(e1, e2))
|
||||
}
|
||||
(Parenthesized(exprs1), exprs2) => {
|
||||
exprs1.len() == 1 && exprs_are_equivalent(&exprs1[0], exprs2)
|
||||
}
|
||||
// fall back to naive equality check
|
||||
_ => expr1 == expr2,
|
||||
}
|
||||
}
|
||||
|
||||
/// SQL literal
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum Literal {
|
||||
@@ -534,6 +640,19 @@ impl Literal {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
pub fn is_equal(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Self::Numeric(n1), Self::Numeric(n2)) => n1 == n2,
|
||||
(Self::String(s1), Self::String(s2)) => s1 == s2,
|
||||
(Self::Blob(b1), Self::Blob(b2)) => b1 == b2,
|
||||
(Self::Keyword(k1), Self::Keyword(k2)) => k1 == k2,
|
||||
(Self::Null, Self::Null) => true,
|
||||
(Self::CurrentDate, Self::CurrentDate) => true,
|
||||
(Self::CurrentTime, Self::CurrentTime) => true,
|
||||
(Self::CurrentTimestamp, Self::CurrentTimestamp) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Textual comparison operator in an expression
|
||||
@@ -648,6 +767,20 @@ impl From<YYCODETYPE> for Operator {
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator {
|
||||
pub fn is_commutative(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Operator::Add
|
||||
| Operator::Multiply
|
||||
| Operator::BitwiseAnd
|
||||
| Operator::BitwiseOr
|
||||
| Operator::Equals
|
||||
| Operator::NotEquals
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Unary operators
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum UnaryOperator {
|
||||
@@ -1871,4 +2004,139 @@ mod test {
|
||||
fn name(s: &'static str) -> Name {
|
||||
Name(s.to_owned())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exprs_are_equivalent() {
|
||||
use super::{Expr, Id, Literal, Operator::*};
|
||||
|
||||
// commutative addition
|
||||
let expr1 = Expr::Binary(
|
||||
Box::new(Expr::Literal(Literal::Numeric("1".to_string()))),
|
||||
Add,
|
||||
Box::new(Expr::Literal(Literal::Numeric("2".to_string()))),
|
||||
);
|
||||
let expr2 = Expr::Binary(
|
||||
Box::new(Expr::Literal(Literal::Numeric("2".to_string()))),
|
||||
Add,
|
||||
Box::new(Expr::Literal(Literal::Numeric("1".to_string()))),
|
||||
);
|
||||
assert!(super::exprs_are_equivalent(&expr1, &expr2));
|
||||
|
||||
// non-commutative subtraction
|
||||
let expr3 = Expr::Binary(
|
||||
Box::new(Expr::Literal(Literal::Numeric("3".to_string()))),
|
||||
Subtract,
|
||||
Box::new(Expr::Literal(Literal::Numeric("2".to_string()))),
|
||||
);
|
||||
let expr4 = Expr::Binary(
|
||||
Box::new(Expr::Literal(Literal::Numeric("2".to_string()))),
|
||||
Subtract,
|
||||
Box::new(Expr::Literal(Literal::Numeric("3".to_string()))),
|
||||
);
|
||||
assert!(!super::exprs_are_equivalent(&expr3, &expr4));
|
||||
|
||||
// case-insensitive function calls
|
||||
let func1 = Expr::FunctionCall {
|
||||
name: Id("SUM".to_string()),
|
||||
distinctness: None,
|
||||
args: Some(vec![Expr::Id(Id("x".to_string()))]),
|
||||
order_by: None,
|
||||
filter_over: None,
|
||||
};
|
||||
let func2 = Expr::FunctionCall {
|
||||
name: Id("sum".to_string()),
|
||||
distinctness: None,
|
||||
args: Some(vec![Expr::Id(Id("x".to_string()))]),
|
||||
order_by: None,
|
||||
filter_over: None,
|
||||
};
|
||||
assert!(super::exprs_are_equivalent(&func1, &func2));
|
||||
|
||||
// DISTINCT function argument mismatch
|
||||
let func3 = Expr::FunctionCall {
|
||||
name: Id("SUM".to_string()),
|
||||
distinctness: Some(super::Distinctness::Distinct),
|
||||
args: Some(vec![Expr::Id(Id("x".to_string()))]),
|
||||
order_by: None,
|
||||
filter_over: None,
|
||||
};
|
||||
assert!(!super::exprs_are_equivalent(&func1, &func3));
|
||||
|
||||
// commutative multiplication
|
||||
let expr5 = Expr::Binary(
|
||||
Box::new(Expr::Literal(Literal::Numeric("4".to_string()))),
|
||||
Multiply,
|
||||
Box::new(Expr::Literal(Literal::Numeric("5".to_string()))),
|
||||
);
|
||||
let expr6 = Expr::Binary(
|
||||
Box::new(Expr::Literal(Literal::Numeric("5".to_string()))),
|
||||
Multiply,
|
||||
Box::new(Expr::Literal(Literal::Numeric("4".to_string()))),
|
||||
);
|
||||
assert!(super::exprs_are_equivalent(&expr5, &expr6));
|
||||
|
||||
// parenthesized expressions
|
||||
let expr7 = Expr::Parenthesized(vec![Expr::Binary(
|
||||
Box::new(Expr::Literal(Literal::Numeric("6".to_string()))),
|
||||
Add,
|
||||
Box::new(Expr::Literal(Literal::Numeric("7".to_string()))),
|
||||
)]);
|
||||
let expr8 = Expr::Binary(
|
||||
Box::new(Expr::Literal(Literal::Numeric("6".to_string()))),
|
||||
Add,
|
||||
Box::new(Expr::Literal(Literal::Numeric("7".to_string()))),
|
||||
);
|
||||
assert!(super::exprs_are_equivalent(&expr7, &expr8));
|
||||
|
||||
// LIKE expressions with escape clauses
|
||||
let expr9 = Expr::Like {
|
||||
lhs: Box::new(Expr::Id(Id("name".to_string()))),
|
||||
not: false,
|
||||
op: super::LikeOperator::Like,
|
||||
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
|
||||
escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))),
|
||||
};
|
||||
let expr10 = Expr::Like {
|
||||
lhs: Box::new(Expr::Id(Id("name".to_string()))),
|
||||
not: false,
|
||||
op: super::LikeOperator::Like,
|
||||
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
|
||||
escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))),
|
||||
};
|
||||
assert!(super::exprs_are_equivalent(&expr9, &expr10));
|
||||
|
||||
// differing escape clauses in LIKE
|
||||
let expr11 = Expr::Like {
|
||||
lhs: Box::new(Expr::Id(Id("name".to_string()))),
|
||||
not: false,
|
||||
op: super::LikeOperator::Like,
|
||||
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
|
||||
escape: Some(Box::new(Expr::Literal(Literal::String("#".to_string())))),
|
||||
};
|
||||
assert!(!super::exprs_are_equivalent(&expr9, &expr11));
|
||||
|
||||
// BETWEEN expressions
|
||||
let expr12 = Expr::Between {
|
||||
lhs: Box::new(Expr::Id(Id("age".to_string()))),
|
||||
not: false,
|
||||
start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))),
|
||||
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
|
||||
};
|
||||
let expr13 = Expr::Between {
|
||||
lhs: Box::new(Expr::Id(Id("age".to_string()))),
|
||||
not: false,
|
||||
start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))),
|
||||
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
|
||||
};
|
||||
assert!(super::exprs_are_equivalent(&expr12, &expr13));
|
||||
|
||||
// differing BETWEEN bounds
|
||||
let expr14 = Expr::Between {
|
||||
lhs: Box::new(Expr::Id(Id("age".to_string()))),
|
||||
not: false,
|
||||
start: Box::new(Expr::Literal(Literal::Numeric("20".to_string()))),
|
||||
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
|
||||
};
|
||||
assert!(!super::exprs_are_equivalent(&expr12, &expr14));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user