From e0c2a09d71a10c4b8f3a21cb6dccb3e0625cd653 Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Tue, 27 May 2025 17:28:49 -0300 Subject: [PATCH] more tests for select + fixes --- .../sqlite3-parser/src/to_sql_string/expr.rs | 86 ++++++------ .../src/to_sql_string/stmt/mod.rs | 16 +++ .../src/to_sql_string/stmt/select.rs | 122 +++++++++++++++--- 3 files changed, 167 insertions(+), 57 deletions(-) diff --git a/vendored/sqlite3-parser/src/to_sql_string/expr.rs b/vendored/sqlite3-parser/src/to_sql_string/expr.rs index 7fbb875e0..861a02975 100644 --- a/vendored/sqlite3-parser/src/to_sql_string/expr.rs +++ b/vendored/sqlite3-parser/src/to_sql_string/expr.rs @@ -42,9 +42,10 @@ impl ToSqlString for Expr { ret.push_str("CASE "); if let Some(base) = base { ret.push_str(&base.to_sql_string(context)); + ret.push(' '); } for (when, then) in when_then_pairs { - ret.push_str(" WHEN "); + ret.push_str("WHEN "); ret.push_str(&when.to_sql_string(context)); ret.push_str(" THEN "); ret.push_str(&then.to_sql_string(context)); @@ -103,13 +104,14 @@ impl ToSqlString for Expr { ret.push(')'); if let Some(filter_over) = filter_over { if let Some(filter) = &filter_over.filter_clause { - ret.push_str(" FILTER ("); - ret.push_str("WHERE "); - ret.push_str(&filter.to_sql_string(context)); - ret.push(')'); + ret.push_str(&format!( + " FILTER (WHERE {})", + filter.to_sql_string(context) + )); } - if let Some(_over) = &filter_over.over_clause { - todo!() + if let Some(over) = &filter_over.over_clause { + ret.push(' '); + ret.push_str(&over.to_sql_string(context)); } } } @@ -118,13 +120,14 @@ impl ToSqlString for Expr { ret.push_str("(*)"); if let Some(filter_over) = filter_over { if let Some(filter) = &filter_over.filter_clause { - ret.push_str(" FILTER ("); - ret.push_str("WHERE "); - ret.push_str(&filter.to_sql_string(context)); - ret.push(')'); + ret.push_str(&format!( + " FILTER (WHERE {})", + filter.to_sql_string(context) + )); } - if let Some(_over) = &filter_over.over_clause { - todo!() + if let Some(over) = &filter_over.over_clause { + ret.push(' '); + ret.push_str(&over.to_sql_string(context)); } } } @@ -147,31 +150,27 @@ impl ToSqlString for Expr { table: _, } => todo!(), Expr::InList { lhs, not, rhs } => { - ret.push_str(&lhs.to_sql_string(context)); - ret.push(' '); - if *not { - ret.push_str("NOT "); - } - ret.push('('); - if let Some(rhs) = rhs { - let joined_args = rhs - .iter() - .map(|expr| expr.to_sql_string(context)) - .collect::>() - .join(", "); - ret.push_str(&joined_args); - } - ret.push(')'); + ret.push_str(&format!( + "{} {}IN ({})", + lhs.to_sql_string(context), + if *not { "NOT " } else { "" }, + if let Some(rhs) = rhs { + rhs.iter() + .map(|expr| expr.to_sql_string(context)) + .collect::>() + .join(", ") + } else { + "".to_string() + } + )); } Expr::InSelect { lhs, not, rhs } => { - ret.push_str(&lhs.to_sql_string(context)); - ret.push(' '); - if *not { - ret.push_str("NOT "); - } - ret.push('('); - ret.push_str(&rhs.to_sql_string(context)); - ret.push(')'); + ret.push_str(&format!( + "{} {}IN ({})", + lhs.to_sql_string(context), + if *not { "NOT " } else { "" }, + rhs.to_sql_string(context) + )); } Expr::InTable { lhs, @@ -409,5 +408,20 @@ impl ToSqlString for ast::UnaryOperator { } } +impl ToSqlString for ast::Over { + fn to_sql_string(&self, context: &C) -> String { + let mut ret = vec!["OVER".to_string()]; + match self { + Self::Name(name) => { + ret.push(name.0.clone()); + } + Self::Window(window) => { + ret.push(window.to_sql_string(context)); + } + } + ret.join(" ") + } +} + #[cfg(test)] mod tests {} diff --git a/vendored/sqlite3-parser/src/to_sql_string/stmt/mod.rs b/vendored/sqlite3-parser/src/to_sql_string/stmt/mod.rs index 8353b9c46..bf314126e 100644 --- a/vendored/sqlite3-parser/src/to_sql_string/stmt/mod.rs +++ b/vendored/sqlite3-parser/src/to_sql_string/stmt/mod.rs @@ -35,6 +35,22 @@ mod tests { ); } }; + ($test_name:ident, $input:literal, $($attribute:meta),*) => { + #[test] + $(#[$attribute])* + fn $test_name() { + let context = crate::to_sql_string::stmt::tests::TestContext; + let input: &str = $input; + let mut parser = crate::lexer::sql::Parser::new(input.as_bytes()); + let cmd = fallible_iterator::FallibleIterator::next(&mut parser) + .unwrap() + .unwrap(); + assert_eq!( + input, + crate::to_sql_string::ToSqlString::to_sql_string(cmd.stmt(), &context) + ); + } + } } pub(crate) struct TestContext; diff --git a/vendored/sqlite3-parser/src/to_sql_string/stmt/select.rs b/vendored/sqlite3-parser/src/to_sql_string/stmt/select.rs index 99d554e41..0d162b66d 100644 --- a/vendored/sqlite3-parser/src/to_sql_string/stmt/select.rs +++ b/vendored/sqlite3-parser/src/to_sql_string/stmt/select.rs @@ -5,6 +5,7 @@ use crate::{ impl ToSqlString for ast::Select { fn to_sql_string(&self, context: &C) -> String { + dbg!(&self); let mut ret = Vec::new(); if let Some(with) = &self.with { let joined_expr = with @@ -15,8 +16,8 @@ impl ToSqlString for ast::Select { .join(", "); ret.push(format!( - "WITH {} {}", - if with.recursive { "RECURSIVE" } else { "" }, + "WITH{} {}", + if with.recursive { " RECURSIVE " } else { "" }, joined_expr )); } @@ -212,23 +213,29 @@ impl ToSqlString for ast::SelectTable { impl ToSqlString for ast::CommonTableExpr { fn to_sql_string(&self, context: &C) -> String { - let mut ret = self.tbl_name.0.clone(); + let mut ret = Vec::with_capacity(self.columns.as_ref().map_or(2, |cols| cols.len())); + ret.push(self.tbl_name.0.clone()); if let Some(cols) = &self.columns { - ret.push_str(" ("); let joined_cols = cols .iter() .map(|col| col.to_sql_string(context)) .collect::>() .join(", "); - ret.push_str(&joined_cols); - ret.push(')'); + + ret.push(format!("({})", joined_cols)); } - ret.push_str(" AS "); - ret.push_str(&self.materialized.to_sql_string(context)); - ret.push_str(" ("); - ret.push_str(&self.select.to_sql_string(context)); - ret.push(')'); - ret + ret.push(format!( + "AS {}({})", + { + let mut materialized = self.materialized.to_sql_string(context); + if !materialized.is_empty() { + materialized.push(' '); + } + materialized + }, + self.select.to_sql_string(context) + )); + ret.join(" ") } } @@ -345,7 +352,7 @@ impl ToSqlString for ast::JoinOperator { Self::TypedJoin(join) => { let join_keyword = "JOIN"; if let Some(join) = join { - format!("{} {}", join_keyword, join.to_sql_string(context)) + format!("{} {}", join.to_sql_string(context), join_keyword) } else { join_keyword.to_string() } @@ -360,15 +367,21 @@ impl ToSqlString for ast::JoinType { if self.contains(Self::NATURAL) { modifiers.push("NATURAL"); } - if self.contains(Self::LEFT) { - modifiers.push("LEFT"); - } - if self.contains(Self::OUTER) { - modifiers.push("OUTER"); - } - if self.contains(Self::RIGHT) { - modifiers.push("RIGHT"); + if self.contains(Self::LEFT) || self.contains(Self::RIGHT) { + // TODO: I think the parser incorrectly asigns outer to every LEFT and RIGHT query + if self.contains(Self::LEFT | Self::RIGHT) { + modifiers.push("FULL"); + } else if self.contains(Self::LEFT) { + modifiers.push("LEFT"); + } else if self.contains(Self::RIGHT) { + modifiers.push("RIGHT"); + } + // FIXME: ignore outer joins as I think they are parsed incorrectly in the bitflags + // if self.contains(Self::OUTER) { + // modifiers.push("OUTER"); + // } } + if self.contains(Self::INNER) { modifiers.push("INNER"); } @@ -558,4 +571,71 @@ mod tests { test_select_with_subquery, "SELECT a FROM (SELECT b FROM t) AS sub" ); + + to_sql_string_test!( + test_select_nested_subquery, + "SELECT a FROM (SELECT b FROM (SELECT c FROM t WHERE c > 10) AS sub1 WHERE b < 20) AS sub2" + ); + + to_sql_string_test!( + test_select_multiple_joins, + "SELECT t1.a, t2.b, t3.c FROM t1 JOIN t2 ON t1.id = t2.id LEFT JOIN t3 ON t2.id = t3.id" + ); + + to_sql_string_test!( + test_select_with_cte, + "WITH cte AS (SELECT a FROM t WHERE b = 1) SELECT a FROM cte WHERE a > 10" + ); + + to_sql_string_test!( + test_select_with_window_function, + "SELECT a, ROW_NUMBER() OVER (PARTITION BY b ORDER BY c DESC) AS rn FROM t" + ); + + to_sql_string_test!( + test_select_with_complex_where, + "SELECT a FROM t WHERE b IN (1, 2, 3) AND c BETWEEN 10 AND 20 OR d IS NULL" + ); + + to_sql_string_test!( + test_select_with_case, + "SELECT CASE WHEN a > 0 THEN 'positive' ELSE 'non-positive' END AS result FROM t" + ); + + to_sql_string_test!(test_select_with_aggregate_and_join, "SELECT t1.a, COUNT(t2.b) FROM t1 LEFT JOIN t2 ON t1.id = t2.id GROUP BY t1.a HAVING COUNT(t2.b) > 5"); + + to_sql_string_test!(test_select_with_multiple_ctes, "WITH cte1 AS (SELECT a FROM t WHERE b = 1), cte2 AS (SELECT c FROM t2 WHERE d = 2) SELECT cte1.a, cte2.c FROM cte1 JOIN cte2 ON cte1.a = cte2.c"); + + to_sql_string_test!( + test_select_with_union, + "SELECT a FROM t1 UNION SELECT b FROM t2" + ); + + to_sql_string_test!( + test_select_with_union_all, + "SELECT a FROM t1 UNION ALL SELECT b FROM t2" + ); + + to_sql_string_test!( + test_select_with_exists, + "SELECT a FROM t WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.b = t.a)" + ); + + to_sql_string_test!( + test_select_with_correlated_subquery, + "SELECT a, (SELECT COUNT(*) FROM t2 WHERE t2.b = t.a) AS count_b FROM t" + ); + + to_sql_string_test!( + test_select_with_complex_order_by, + "SELECT a, b FROM t ORDER BY CASE WHEN a IS NULL THEN 1 ELSE 0 END, b ASC, c DESC" + ); + + to_sql_string_test!( + test_select_with_full_outer_join, + "SELECT t1.a, t2.b FROM t1 FULL OUTER JOIN t2 ON t1.id = t2.id", + ignore = "OUTER JOIN is incorrectly parsed in parser" + ); + + to_sql_string_test!(test_select_with_aggregate_window, "SELECT a, SUM(b) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS running_sum FROM t"); }