diff --git a/vendored/sqlite3-parser/src/to_sql_string/expr.rs b/vendored/sqlite3-parser/src/to_sql_string/expr.rs index 9744a5148..84e63104b 100644 --- a/vendored/sqlite3-parser/src/to_sql_string/expr.rs +++ b/vendored/sqlite3-parser/src/to_sql_string/expr.rs @@ -408,3 +408,6 @@ impl ToSqlString for ast::UnaryOperator { .to_string() } } + +#[cfg(test)] +mod tests {} diff --git a/vendored/sqlite3-parser/src/to_sql_string/mod.rs b/vendored/sqlite3-parser/src/to_sql_string/mod.rs index 48ef90c03..d47617bce 100644 --- a/vendored/sqlite3-parser/src/to_sql_string/mod.rs +++ b/vendored/sqlite3-parser/src/to_sql_string/mod.rs @@ -26,3 +26,20 @@ impl ToSqlString for Box { T::to_sql_string(&self, context) } } + +#[cfg(test)] +mod tests { + use super::ToSqlContext; + + struct TestContext; + + impl ToSqlContext for TestContext { + fn get_column_name(&self, _table_id: crate::ast::TableInternalId, _col_idx: usize) -> &str { + "placeholder_column" + } + + fn get_table_name(&self, _id: crate::ast::TableInternalId) -> &str { + "placeholder_table" + } + } +} diff --git a/vendored/sqlite3-parser/src/to_sql_string/stmt/select.rs b/vendored/sqlite3-parser/src/to_sql_string/stmt/select.rs index 64bdf9baf..de1aa0c84 100644 --- a/vendored/sqlite3-parser/src/to_sql_string/stmt/select.rs +++ b/vendored/sqlite3-parser/src/to_sql_string/stmt/select.rs @@ -1,36 +1,499 @@ use crate::{ - ast, + ast::{self}, to_sql_string::{ToSqlContext, ToSqlString}, }; impl ToSqlString for ast::Select { fn to_sql_string(&self, context: &C) -> String { - let mut ret = String::new(); - ret + let mut ret = Vec::new(); + if let Some(with) = &self.with { + let joined_expr = with + .ctes + .iter() + .map(|cte| cte.to_sql_string(context)) + .collect::>() + .join(", "); + + ret.push(format!( + "WITH {} {}", + if with.recursive { "RECURSIVE" } else { "" }, + joined_expr + )); + } + + ret.push(self.body.to_sql_string(context)); + + if let Some(order_by) = &self.order_by { + // TODO: SortedColumn missing collation in ast + let joined_cols = order_by + .iter() + .map(|col| col.to_sql_string(context)) + .collect::>() + .join(", "); + ret.push(format!("ORDER BY {}", joined_cols)); + } + if let Some(limit) = &self.limit { + ret.push(format!("LIMIT {}", limit.expr.to_sql_string(context))); + // TODO: missing , + expr in ast + if let Some(offset) = &limit.offset { + ret.push(format!("OFFSET {}", offset.to_sql_string(context))); + } + } + ret.join(" ") } } impl ToSqlString for ast::SelectBody { fn to_sql_string(&self, context: &C) -> String { - let mut ret = String::new(); + let mut ret = self.select.to_sql_string(context); + + if let Some(compounds) = &self.compounds { + ret.push(' '); + let compound_selects = compounds + .iter() + .map(|compound_select| { + let mut curr = compound_select.operator.to_sql_string(context); + curr.push(' '); + curr.push_str(&compound_select.select.to_sql_string(context)); + curr + }) + .collect::>() + .join(" "); + ret.push_str(&compound_selects); + } ret } } impl ToSqlString for ast::OneSelect { fn to_sql_string(&self, context: &C) -> String { - let mut ret = String::new(); match self { - ast::OneSelect::Select(select) => ret, - // TODO: come back here when we implement ToSqlString for Expr - ast::OneSelect::Values(values) => ret, + ast::OneSelect::Select(select) => select.to_sql_string(context), + ast::OneSelect::Values(values) => { + let joined_values = values + .iter() + .map(|value| { + let joined_value = value + .iter() + .map(|e| e.to_sql_string(context)) + .collect::>() + .join(","); + format!("({})", joined_value) + }) + .collect::>() + .join(", "); + joined_values + } } } } impl ToSqlString for ast::SelectInner { fn to_sql_string(&self, context: &C) -> String { - let mut ret = String::new(); + let mut ret = Vec::with_capacity(2 + self.columns.len()); + ret.push("SELECT".to_string()); + if let Some(distinct) = self.distinctness { + ret.push(distinct.to_sql_string(context)); + } + let joined_cols = self + .columns + .iter() + .map(|col| col.to_sql_string(context)) + .collect::>() + .join(", "); + ret.push(joined_cols); + + if let Some(from) = &self.from { + ret.push(from.to_sql_string(context)); + } + if let Some(where_expr) = &self.where_clause { + ret.push("WHERE".to_string()); + ret.push(where_expr.to_sql_string(context)); + } + if let Some(group_by) = &self.group_by { + ret.push(group_by.to_sql_string(context)); + } + if let Some(window_clause) = &self.window_clause { + ret.push("WINDOW".to_string()); + let joined_window = window_clause + .iter() + .map(|window_def| window_def.to_sql_string(context)) + .collect::>() + .join(","); + ret.push(joined_window); + } + + ret.join(" ") + } +} + +impl ToSqlString for ast::FromClause { + fn to_sql_string(&self, context: &C) -> String { + let mut ret = String::from("FROM"); + if let Some(select_table) = &self.select { + ret.push(' '); + ret.push_str(&select_table.to_sql_string(context)); + } + if let Some(joins) = &self.joins { + ret.push(' '); + let joined_joins = joins + .iter() + .map(|join| { + let mut curr = join.operator.to_sql_string(context); + curr.push(' '); + curr.push_str(&join.table.to_sql_string(context)); + if let Some(join_constraint) = &join.constraint { + curr.push(' '); + curr.push_str(&join_constraint.to_sql_string(context)); + } + curr + }) + .collect::>() + .join(" "); + ret.push_str(&joined_joins); + } ret } } + +impl ToSqlString for ast::SelectTable { + fn to_sql_string(&self, context: &C) -> String { + let mut ret = String::new(); + match self { + Self::Table(name, alias, indexed) => { + ret.push_str(&name.to_sql_string(context)); + if let Some(alias) = alias { + ret.push(' '); + ret.push_str(&alias.to_sql_string(context)); + } + if let Some(indexed) = indexed { + ret.push(' '); + ret.push_str(&indexed.to_sql_string(context)); + } + } + Self::TableCall(table_func, args, alias) => { + ret.push_str(&table_func.to_sql_string(context)); + if let Some(args) = args { + ret.push(' '); + let joined_args = args + .iter() + .map(|arg| arg.to_sql_string(context)) + .collect::>() + .join(", "); + ret.push_str(&joined_args); + } + if let Some(alias) = alias { + ret.push(' '); + ret.push_str(&alias.to_sql_string(context)); + } + } + Self::Select(select, alias) => { + ret.push_str(&select.to_sql_string(context)); + if let Some(alias) = alias { + ret.push(' '); + ret.push_str(&alias.to_sql_string(context)); + } + } + Self::Sub(from_clause, alias) => { + ret.push_str(&from_clause.to_sql_string(context)); + if let Some(alias) = alias { + ret.push(' '); + ret.push_str(&alias.to_sql_string(context)); + } + } + } + ret + } +} + +impl ToSqlString for ast::CommonTableExpr { + fn to_sql_string(&self, context: &C) -> String { + let mut ret = self.tbl_name.0.clone(); + if let Some(cols) = &self.columns { + ret.push_str(" ("); + let joined_cols = cols + .iter() + .map(|col| col.to_sql_string(context)) + .collect::>() + .join(", "); + ret.push_str(&joined_cols); + ret.push(')'); + } + ret.push_str(" AS "); + ret.push_str(&self.materialized.to_sql_string(context)); + ret.push_str(" ("); + ret.push_str(&self.select.to_sql_string(context)); + ret.push(')'); + ret + } +} + +impl ToSqlString for ast::IndexedColumn { + fn to_sql_string(&self, _context: &C) -> String { + self.col_name.0.to_string() + } +} + +impl ToSqlString for ast::SortedColumn { + fn to_sql_string(&self, context: &C) -> String { + let mut curr = self.expr.to_sql_string(context); + if let Some(sort_order) = self.order { + curr.push(' '); + curr.push_str(&sort_order.to_sql_string(context)); + } + if let Some(nulls_order) = self.nulls { + curr.push(' '); + curr.push_str(&nulls_order.to_sql_string(context)); + } + curr + } +} + +impl ToSqlString for ast::SortOrder { + fn to_sql_string(&self, _context: &C) -> String { + match self { + Self::Asc => "ASC", + Self::Desc => "DESC", + } + .to_string() + } +} + +impl ToSqlString for ast::NullsOrder { + fn to_sql_string(&self, _context: &C) -> String { + match self { + Self::First => "NULLS FIRST", + Self::Last => "NULLS LAST", + } + .to_string() + } +} + +impl ToSqlString for ast::Materialized { + fn to_sql_string(&self, _context: &C) -> String { + match self { + Self::Any => "", + Self::No => "NOT MATERIALIZED", + Self::Yes => "MATERIALIZED", + } + .to_string() + } +} + +impl ToSqlString for ast::CompoundOperator { + fn to_sql_string(&self, _context: &C) -> String { + match self { + Self::Except => "EXCEPT", + Self::Intersect => "INTERSECT", + Self::Union => "UNION", + Self::UnionAll => "UNION ALL", + } + .to_string() + } +} + +impl ToSqlString for ast::ResultColumn { + fn to_sql_string(&self, context: &C) -> String { + let mut ret = String::new(); + match self { + Self::Expr(expr, alias) => { + ret.push_str(&expr.to_sql_string(context)); + if let Some(alias) = alias { + ret.push(' '); + ret.push_str(&alias.to_sql_string(context)); + } + } + Self::Star => { + ret.push('*'); + } + Self::TableStar(name) => { + ret.push_str(&format!("{}.*", name.0)); + } + } + ret + } +} + +impl ToSqlString for ast::As { + fn to_sql_string(&self, _context: &C) -> String { + match self { + Self::As(alias) => { + format!("AS {}", alias.0) + } + Self::Elided(alias) => alias.0.clone(), + } + } +} + +impl ToSqlString for ast::Indexed { + fn to_sql_string(&self, _context: &C) -> String { + match self { + Self::NotIndexed => "NOT INDEXED".to_string(), + Self::IndexedBy(name) => format!("INDEXED BY {}", name.0), + } + } +} + +impl ToSqlString for ast::JoinOperator { + fn to_sql_string(&self, context: &C) -> String { + match self { + Self::Comma => ",".to_string(), + Self::TypedJoin(join) => format!( + "JOIN {}", + join.map_or(String::new(), |join| join.to_sql_string(context)) + ), + } + } +} + +impl ToSqlString for ast::JoinType { + fn to_sql_string(&self, _context: &C) -> String { + let mut modifiers = Vec::new(); + if self.contains(Self::NATURAL) { + modifiers.push("NATURAL"); + } + if self.contains(Self::LEFT) { + modifiers.push("LEFT"); + } + if self.contains(Self::OUTER) { + modifiers.push("OUTER"); + } + if self.contains(Self::RIGHT) { + modifiers.push("RIGHT"); + } + if self.contains(Self::INNER) { + modifiers.push("INNER"); + } + if self.contains(Self::CROSS) { + modifiers.push("CROSS"); + } + modifiers.join(" ") + } +} + +impl ToSqlString for ast::JoinConstraint { + fn to_sql_string(&self, context: &C) -> String { + match self { + Self::On(expr) => { + format!("ON {}", expr.to_sql_string(context)) + } + Self::Using(col_names) => { + let joined_names = col_names + .iter() + .map(|col| col.0.clone()) + .collect::>() + .join(","); + format!("USING ({})", joined_names) + } + } + } +} + +impl ToSqlString for ast::GroupBy { + fn to_sql_string(&self, context: &C) -> String { + let mut ret = String::from("GROUP BY "); + let curr = self + .exprs + .iter() + .map(|expr| expr.to_sql_string(context)) + .collect::>() + .join(","); + ret.push_str(&curr); + if let Some(having) = &self.having { + ret.push_str(&format!(" HAVING {}", having.to_sql_string(context))); + } + ret + } +} + +impl ToSqlString for ast::WindowDef { + fn to_sql_string(&self, context: &C) -> String { + format!("{} AS {}", self.name.0, self.window.to_sql_string(context)) + } +} + +impl ToSqlString for ast::Window { + fn to_sql_string(&self, context: &C) -> String { + let mut ret = Vec::new(); + if let Some(name) = &self.base { + ret.push(name.0.clone()); + } + if let Some(partition) = &self.partition_by { + let joined_exprs = partition + .iter() + .map(|e| e.to_sql_string(context)) + .collect::>() + .join(","); + ret.push(format!("PARTITION BY {}", joined_exprs)); + } + if let Some(order_by) = &self.order_by { + let joined_cols = order_by + .iter() + .map(|col| col.to_sql_string(context)) + .collect::>() + .join(", "); + ret.push(format!("ORDER BY {}", joined_cols)); + } + if let Some(frame_claue) = &self.frame_clause { + ret.push(frame_claue.to_sql_string(context)); + } + format!("({})", ret.join(" ")) + } +} + +impl ToSqlString for ast::FrameClause { + fn to_sql_string(&self, context: &C) -> String { + let mut ret = Vec::new(); + ret.push(self.mode.to_sql_string(context)); + let start_sql = self.start.to_sql_string(context); + if let Some(end) = &self.end { + ret.push(format!( + "BETWEEN {} AND {}", + start_sql, + end.to_sql_string(context) + )); + } else { + ret.push(start_sql); + } + if let Some(exclude) = &self.exclude { + ret.push(exclude.to_sql_string(context)); + } + + ret.join(" ") + } +} + +impl ToSqlString for ast::FrameMode { + fn to_sql_string(&self, _context: &C) -> String { + match self { + Self::Groups => "GROUPS", + Self::Range => "RANGE", + Self::Rows => "ROWS", + } + .to_string() + } +} + +impl ToSqlString for ast::FrameBound { + fn to_sql_string(&self, context: &C) -> String { + match self { + Self::CurrentRow => "CURRENT ROW".to_string(), + Self::Following(expr) => format!("{} FOLLOWING", expr.to_sql_string(context)), + Self::Preceding(expr) => format!("{} PRECEDING", expr.to_sql_string(context)), + Self::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(), + Self::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(), + } + } +} + +impl ToSqlString for ast::FrameExclude { + fn to_sql_string(&self, _context: &C) -> String { + let clause = match self { + Self::CurrentRow => "CURRENT ROW", + Self::Group => "GROUP", + Self::NoOthers => "NO OTHERS", + Self::Ties => "TIES", + }; + format!("EXCLUDE {}", clause) + } +}