//! Abstract Syntax Tree
pub mod check;
pub mod fmt;

use std::num::ParseIntError;
use std::ops::Deref;
use std::str::{self, Bytes, FromStr};
use strum_macros::{EnumIter, EnumString};

use fmt::{ToTokens, TokenStream};
use indexmap::{IndexMap, IndexSet};

use crate::custom_err;
use crate::dialect::TokenType::{self, *};
use crate::dialect::{from_bytes, from_token, Token};
use crate::parser::{parse::YYCODETYPE, ParserError};

// NOTE(review): in this view of the file the generic type arguments appear to
// have been stripped (e.g. `IndexSet` below has no element type even though
// `append` inserts `String`s into it), so the code does not compile as shown.
// Restore the `<...>` arguments from version control.

/// `?` or `$` Prepared statement arg placeholder(s)
///
/// Feeding a statement's tokens to this [`TokenStream`] (via `append`)
/// tallies the statement's parameters.
#[derive(Default)]
pub struct ParameterInfo {
    /// Number of SQL parameters in a prepared statement, like `sqlite3_bind_parameter_count`
    pub count: u32,
    /// Parameter name(s) if any, distinct and in first-seen order
    pub names: IndexSet,
}

// https://sqlite.org/lang_expr.html#parameters
impl TokenStream for ParameterInfo {
    type Error = ParseIntError;

    /// Records one token; only `TK_VARIABLE` tokens that carry a value count.
    ///
    /// * `?` takes the next index (`count + 1`, saturating).
    /// * `?NNN` raises `count` to `NNN` when `NNN` is larger.
    /// * any other spelling is a named parameter and takes the next index
    ///   only the first time that exact spelling is seen.
    ///
    /// # Errors
    /// Returns the [`ParseIntError`] when the digits of `?NNN` do not parse
    /// as a `u32`.
    fn append(&mut self, ty: TokenType, value: Option<&str>) -> Result<(), Self::Error> {
        if ty == TK_VARIABLE {
            if let Some(variable) = value {
                if variable == "?" {
                    self.count = self.count.saturating_add(1);
                } else if variable.as_bytes()[0] == b'?' {
                    // NOTE(review): the `[0]` index above panics on an empty
                    // variable; presumably the tokenizer never yields one — confirm.
                    let n = u32::from_str(&variable[1..])?;
                    if n > self.count {
                        self.count = n;
                    }
                } else if self.names.insert(variable.to_owned()) {
                    // `insert` is true only for a name not seen before, so a
                    // repeated named parameter does not get a second index.
                    self.count = self.count.saturating_add(1);
                }
            }
        }
        Ok(())
    }
}

/// Statement or Explain statement
// https://sqlite.org/syntax/sql-stmt.html
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Cmd {
    /// `EXPLAIN` statement
    Explain(Stmt),
    /// `EXPLAIN QUERY PLAN` statement
    ExplainQueryPlan(Stmt),
    /// statement
    Stmt(Stmt),
}

/// Which `EXPLAIN` form was parsed: plain `EXPLAIN` or `EXPLAIN QUERY PLAN`
/// (presumably consumed while building a [`Cmd`] — confirm in the parser).
pub(crate) enum ExplainKind {
    Explain,
    QueryPlan,
}

/// SQL statement
// https://sqlite.org/syntax/sql-stmt.html
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Stmt {
    /// `ALTER TABLE`: table name, body
    AlterTable(Box<(QualifiedName, AlterTableBody)>),
    /// `ANALYZE`: object name
    Analyze(Option),
    /// `ATTACH DATABASE`
    Attach {
        /// filename
        // TODO distinction between ATTACH and ATTACH DATABASE
        expr: Box,
        /// schema name
        db_name: Box,
        /// password
        key: Option>,
    },
    /// `BEGIN`: tx type, tx name
    Begin(Option, Option),
    /// `COMMIT`/`END`: tx name
    Commit(Option), // TODO distinction between COMMIT and END
    /// `CREATE INDEX`
    CreateIndex {
        /// `UNIQUE`
        unique: bool,
        /// `IF NOT EXISTS`
        if_not_exists: bool,
        /// index name
        idx_name: Box,
        /// table name
        tbl_name: Name,
        /// indexed columns or expressions
        columns: Vec,
        /// partial index
        where_clause: Option>,
    },
    /// `CREATE TABLE`
    CreateTable {
        /// `TEMPORARY`
        temporary: bool, // TODO distinction between TEMP and TEMPORARY
        /// `IF NOT EXISTS`
        if_not_exists: bool,
        /// table name
        tbl_name: QualifiedName,
        /// table body
        body: Box,
    },
    /// `CREATE TRIGGER`
    CreateTrigger(Box),
    /// `CREATE VIEW`
    CreateView {
        /// `TEMPORARY`
        temporary: bool,
        /// `IF NOT EXISTS`
        if_not_exists: bool,
        /// view name
        view_name: QualifiedName,
        /// columns
        columns: Option>,
        /// query
        // NOTE(review): this view of the source is garbled here — the
        // `CreateView { ... }` variant is never closed before `)`, so the
        // content between `CreateView` and `Update` appears to be missing.
        // Restore this span from version control before building.
        select: Box),
    /// `UPDATE`
    Update(Box),
    /// `VACUUM`: database name, into expr
    Vacuum(Option, Option>),
}

/// `CREATE VIRTUAL TABLE`
#[derive(Clone, Debug,
PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CreateVirtualTable {
    /// `IF NOT EXISTS`
    pub if_not_exists: bool,
    /// table name
    pub tbl_name: QualifiedName,
    /// module name
    pub module_name: Name,
    /// args
    pub args: Option>, // TODO smol str
}

/// `CREATE TRIGGER`
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CreateTrigger {
    /// `TEMPORARY`
    pub temporary: bool,
    /// `IF NOT EXISTS`
    pub if_not_exists: bool,
    /// trigger name
    pub trigger_name: QualifiedName,
    /// `BEFORE`/`AFTER`/`INSTEAD OF`
    pub time: Option,
    /// `DELETE`/`INSERT`/`UPDATE`
    pub event: TriggerEvent,
    /// table name
    pub tbl_name: QualifiedName,
    /// `FOR EACH ROW`
    pub for_each_row: bool,
    /// `WHEN`
    pub when_clause: Option,
    /// statements
    pub commands: Vec,
}

/// `INSERT`
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Insert {
    /// CTE
    pub with: Option,
    /// `OR`
    pub or_conflict: Option, // TODO distinction between REPLACE and INSERT OR REPLACE
    /// table name
    pub tbl_name: QualifiedName,
    /// `COLUMNS`
    pub columns: Option,
    /// `VALUES` or `SELECT`
    pub body: InsertBody,
    /// `RETURNING`
    pub returning: Option>,
}

/// `UPDATE` clause
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Update {
    /// CTE
    pub with: Option,
    /// `OR`
    pub or_conflict: Option,
    /// table name
    pub tbl_name: QualifiedName,
    /// `INDEXED`
    pub indexed: Option,
    /// `SET` assignments
    pub sets: Vec,
    /// `FROM`
    pub from: Option,
    /// `WHERE` clause
    pub where_clause: Option>,
    /// `RETURNING`
    pub returning: Option>,
    /// `ORDER BY`
    pub order_by: Option>,
    /// `LIMIT`
    pub limit: Option>,
}

/// `DELETE`
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Delete {
    /// CTE
    pub with: Option,
    /// `FROM` table name
    pub tbl_name: QualifiedName,
    /// `INDEXED`
    pub indexed: Option,
    /// `WHERE` clause
    pub where_clause: Option>,
    /// `RETURNING`
    pub returning: Option>,
    /// `ORDER BY`
    pub order_by: Option>,
    /// `LIMIT`
    pub limit: Option>,
}

#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// Internal ID of a table reference.
///
/// Used by [Expr::Column] and [Expr::RowId] to refer to a table.
/// E.g. in 'SELECT * FROM t UNION ALL SELECT * FROM t', there are two table references,
/// so there are two TableInternalIds.
///
/// The default ID is `1`. It converts to and from `usize`, can be advanced
/// with `+= usize`, and displays as `t<N>` (e.g. `t1`).
///
/// FIXME: rename this to TableReferenceId.
pub struct TableInternalId(usize);

impl Default for TableInternalId {
    fn default() -> Self {
        Self(1)
    }
}

impl From for TableInternalId {
    fn from(value: usize) -> Self {
        Self(value)
    }
}

impl std::ops::AddAssign for TableInternalId {
    fn add_assign(&mut self, rhs: usize) {
        self.0 += rhs;
    }
}

impl From for usize {
    fn from(value: TableInternalId) -> Self {
        value.0
    }
}

impl std::fmt::Display for TableInternalId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "t{}", self.0)
    }
}

/// SQL expression
// https://sqlite.org/syntax/expr.html
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// NOTE(review): this enum looks garbled in this view. `TableInternalId`'s docs
// refer to `Expr::Column` and `Expr::RowId`, which are not visible below;
// `Exists(Box, }` is malformed; and the tail after `Subquery` appears to mix
// in variants of a different type. Restore from version control before building.
pub enum Expr {
    /// `BETWEEN`
    Between {
        /// expression
        lhs: Box,
        /// `NOT`
        not: bool,
        /// start
        start: Box,
        /// end
        end: Box,
    },
    /// binary expression
    Binary(Box, Operator, Box),
    /// `CASE` expression
    Case {
        /// operand
        base: Option>,
        /// `WHEN` condition `THEN` result
        when_then_pairs: Vec<(Expr, Expr)>,
        /// `ELSE` result
        else_expr: Option>,
    },
    /// CAST expression
    Cast {
        /// expression
        expr: Box,
        /// `AS` type name
        type_name: Option,
    },
    /// `COLLATE`: expression
    Collate(Box, String),
    /// schema-name.table-name.column-name
    DoublyQualified(Name, Name, Name),
    /// `EXISTS` subquery
    Exists(Box, },
    /// `IN` table name / function
    InTable {
        /// expression
        lhs: Box,
        /// `NOT`
        not: bool,
        /// table name
        rhs: QualifiedName,
        /// table function arguments
        args: Option>,
    },
    /// `IS NULL`
    IsNull(Box),
    /// `LIKE`
    Like {
        /// expression
        lhs: Box,
        /// `NOT`
        not: bool,
        /// operator
        op: LikeOperator,
        /// pattern
        rhs: Box,
        /// `ESCAPE` char
        escape: Option>,
    },
    /// Literal expression
    Literal(Literal),
    /// Name
    Name(Name),
    /// `NOT NULL` or `NOTNULL`
    NotNull(Box),
    /// Parenthesized subexpression
    Parenthesized(Vec),
    /// Qualified name
    Qualified(Name, Name),
    /// `RAISE` function call
    Raise(ResolveType, Option>),
    /// Subquery expression
    Subquery(Box, Option),
    /// subquery
    Sub(FromClause, Option),
}

/// Join operators
// https://sqlite.org/syntax/join-operator.html
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum JoinOperator {
    /// `,`
    Comma,
    /// `JOIN`
    TypedJoin(Option),
}

impl JoinOperator {
    /// Builds a typed join from the first join keyword (`token`) plus up to
    /// two further keywords (`n1`, `n2`), OR-ing their [`JoinType`] flags.
    ///
    /// # Errors
    /// Fails when a keyword is not a recognised join type, when the combined
    /// flags are both `INNER` and `OUTER`, or when they contain `OUTER`
    /// without `LEFT` or `RIGHT`.
    pub(crate) fn from(
        token: Token,
        n1: Option,
        n2: Option,
    ) -> Result {
        Ok({
            let mut jt = JoinType::try_from(token.1)?;
            for n in [&n1, &n2].into_iter().flatten() {
                jt |= JoinType::try_from(n.as_str().as_bytes())?;
            }
            // Reject contradictory combinations (e.g. `INNER OUTER`) and an
            // `OUTER` that is not attached to `LEFT`/`RIGHT`.
            if (jt & (JoinType::INNER | JoinType::OUTER)) == (JoinType::INNER | JoinType::OUTER)
                || (jt & (JoinType::OUTER | JoinType::LEFT | JoinType::RIGHT)) == JoinType::OUTER
            {
                return Err(custom_err!(
                    "unsupported JOIN type: {:?} {:?} {:?}",
                    str::from_utf8(token.1),
                    n1,
                    n2
                ));
            }
            Self::TypedJoin(Some(jt))
        })
    }

    /// Returns `true` for a typed join whose flags include `NATURAL`;
    /// `Comma` and `TypedJoin(None)` return `false`.
    fn is_natural(&self) -> bool {
        match self {
            Self::TypedJoin(Some(jt)) => jt.contains(JoinType::NATURAL),
            _ => false,
        }
    }
}

// https://github.com/sqlite/sqlite/blob/80511f32f7e71062026edd471913ef0455563964/src/select.c#L197-L257
bitflags::bitflags! {
    /// `JOIN` types
    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    pub struct JoinType: u8 {
        /// `INNER`
        const INNER = 0x01;
        /// `CROSS` => INNER|CROSS
        const CROSS = 0x02;
        /// `NATURAL`
        const NATURAL = 0x04;
        /// `LEFT` => LEFT|OUTER
        const LEFT = 0x08;
        /// `RIGHT` => RIGHT|OUTER
        const RIGHT = 0x10;
        /// `OUTER`
        const OUTER = 0x20;
    }
}

/// Maps one join keyword (ASCII case-insensitive) to its flags;
/// `FULL` expands to `LEFT | RIGHT | OUTER`. Unknown keywords are an error.
impl TryFrom<&[u8]> for JoinType {
    type Error = ParserError;
    fn try_from(s: &[u8]) -> Result {
        if b"CROSS".eq_ignore_ascii_case(s) {
            Ok(Self::INNER | Self::CROSS)
        } else if b"FULL".eq_ignore_ascii_case(s) {
            Ok(Self::LEFT | Self::RIGHT | Self::OUTER)
        } else if b"INNER".eq_ignore_ascii_case(s) {
            Ok(Self::INNER)
        } else if b"LEFT".eq_ignore_ascii_case(s) {
            Ok(Self::LEFT | Self::OUTER)
        } else if b"NATURAL".eq_ignore_ascii_case(s) {
            Ok(Self::NATURAL)
        } else if b"RIGHT".eq_ignore_ascii_case(s) {
            Ok(Self::RIGHT | Self::OUTER)
        } else if b"OUTER".eq_ignore_ascii_case(s) {
            Ok(Self::OUTER)
        } else {
            Err(custom_err!(
                "unsupported JOIN type: {:?}",
                str::from_utf8(s)
            ))
        }
    }
}

/// `JOIN` constraint
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum JoinConstraint {
    /// `ON`
    On(Expr),
    /// `USING`: col names
    Using(DistinctNames),
}

/// `GROUP BY`
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GroupBy {
    /// expressions
    pub exprs: Vec,
    /// `HAVING`
    pub having: Option>, // HAVING clause on a non-aggregate query
}

/// identifier or string or `CROSS` or `FULL` or `INNER` or `LEFT` or `NATURAL` or `OUTER` or `RIGHT`.
#[derive(Clone, Debug, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum Name { /// Identifier Ident(String), /// Quoted values Quoted(String), } impl Name { /// Constructor pub fn from_token(_ty: YYCODETYPE, token: Token) -> Self { let text = from_bytes(token.1); Self::from_str(&text) } fn as_bytes(&self) -> QuotedIterator<'_> { match self { Name::Ident(s) => QuotedIterator(s.bytes(), 0), Name::Quoted(s) => { if s.is_empty() { return QuotedIterator(s.bytes(), 0); } let bytes = s.as_bytes(); let mut quote = bytes[0]; if quote != b'"' && quote != b'`' && quote != b'\'' && quote != b'[' { return QuotedIterator(s.bytes(), 0); } else if quote == b'[' { quote = b']'; } debug_assert!(bytes.len() > 1); debug_assert_eq!(quote, bytes[bytes.len() - 1]); let sub = &s.as_str()[1..bytes.len() - 1]; if quote == b']' { return QuotedIterator(sub.bytes(), 0); // no escape } QuotedIterator(sub.bytes(), quote) } } } /// as_str pub fn as_str(&self) -> &str { match self { Name::Ident(s) | Name::Quoted(s) => s.as_str(), } } /// Identifying from a string #[allow(clippy::should_implement_trait)] pub fn from_str(s: &str) -> Self { let bytes = s.as_bytes(); if s.is_empty() { return Name::Ident(s.to_string()); } match bytes[0] { b'"' | b'\'' | b'`' | b'[' => Name::Quoted(s.to_string()), _ => Name::Ident(s.to_string()), } } } struct QuotedIterator<'s>(Bytes<'s>, u8); impl Iterator for QuotedIterator<'_> { type Item = u8; fn next(&mut self) -> Option { match self.0.next() { x @ Some(b) => { if b == self.1 && self.0.next() != Some(self.1) { panic!("Malformed string literal: {:?}", self.0); } x } x => x, } } fn size_hint(&self) -> (usize, Option) { if self.1 == 0 { return self.0.size_hint(); } (0, None) } } fn eq_ignore_case_and_quote(mut it: QuotedIterator<'_>, mut other: QuotedIterator<'_>) -> bool { loop { match (it.next(), other.next()) { (Some(b1), Some(b2)) => { if !b1.eq_ignore_ascii_case(&b2) { return false; } } (None, None) => break, _ => return 
false, } } true } /// Ignore case and quote impl std::hash::Hash for Name { fn hash(&self, hasher: &mut H) { self.as_bytes() .for_each(|b| hasher.write_u8(b.to_ascii_lowercase())); } } /// Ignore case and quote impl PartialEq for Name { fn eq(&self, other: &Self) -> bool { eq_ignore_case_and_quote(self.as_bytes(), other.as_bytes()) } } /// Ignore case and quote impl PartialEq for Name { fn eq(&self, other: &str) -> bool { eq_ignore_case_and_quote(self.as_bytes(), QuotedIterator(other.bytes(), 0u8)) } } /// Ignore case and quote impl PartialEq<&str> for Name { fn eq(&self, other: &&str) -> bool { eq_ignore_case_and_quote(self.as_bytes(), QuotedIterator(other.bytes(), 0u8)) } } /// Qualified name #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct QualifiedName { /// schema pub db_name: Option, /// object name pub name: Name, /// alias pub alias: Option, // FIXME restrict alias usage (fullname vs xfullname) } impl QualifiedName { /// Constructor pub fn single(name: Name) -> Self { Self { db_name: None, name, alias: None, } } /// Constructor pub fn fullname(db_name: Name, name: Name) -> Self { Self { db_name: Some(db_name), name, alias: None, } } /// Constructor pub fn xfullname(db_name: Name, name: Name, alias: Name) -> Self { Self { db_name: Some(db_name), name, alias: Some(alias), } } /// Constructor pub fn alias(name: Name, alias: Name) -> Self { Self { db_name: None, name, alias: Some(alias), } } } /// Ordered set of column names #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Names(Vec); impl Names { /// Initialize pub fn new(name: Name) -> Self { let mut dn = Self(Vec::new()); dn.0.push(name); dn } /// Single column name pub fn single(name: Name) -> Self { let mut dn = Self(Vec::with_capacity(1)); dn.0.push(name); dn } /// Push name pub fn insert(&mut self, name: Name) -> Result<(), ParserError> { 
self.0.push(name); Ok(()) } } impl Deref for Names { type Target = Vec; fn deref(&self) -> &Vec { &self.0 } } /// Ordered set of distinct column names #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct DistinctNames(IndexSet); impl DistinctNames { /// Initialize pub fn new(name: Name) -> Self { let mut dn = Self(IndexSet::new()); dn.0.insert(name); dn } /// Single column name pub fn single(name: Name) -> Self { let mut dn = Self(IndexSet::with_capacity(1)); dn.0.insert(name); dn } /// Push a distinct name or fail pub fn insert(&mut self, name: Name) -> Result<(), ParserError> { if self.0.contains(&name) { return Err(custom_err!("column \"{}\" specified more than once", name)); } self.0.insert(name); Ok(()) } } impl Deref for DistinctNames { type Target = IndexSet; fn deref(&self) -> &IndexSet { &self.0 } } /// `ALTER TABLE` body // https://sqlite.org/lang_altertable.html #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum AlterTableBody { /// `RENAME TO`: new table name RenameTo(Name), /// `ADD COLUMN` AddColumn(ColumnDefinition), // TODO distinction between ADD and ADD COLUMN /// `RENAME COLUMN` RenameColumn { /// old name old: Name, /// new name new: Name, }, /// `DROP COLUMN` DropColumn(Name), // TODO distinction between DROP and DROP COLUMN } /// `CREATE TABLE` body // https://sqlite.org/lang_createtable.html // https://sqlite.org/syntax/create-table-stmt.html #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum CreateTableBody { /// columns and constraints ColumnsAndConstraints { /// table column definitions columns: IndexMap, /// table constraints constraints: Option>, /// table options options: TableOptions, }, /// `AS` select AsSelect(Box, Option), /// `DEFAULT VALUES` DefaultValues, } /// `UPDATE ... 
/// SET` assignment: target column name(s) and the assigned expression.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Set {
    /// column name(s)
    pub col_names: Names,
    /// expression
    pub expr: Expr,
}

/// `PRAGMA` body
// https://sqlite.org/syntax/pragma-stmt.html
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PragmaBody {
    /// `=`
    Equals(PragmaValue),
    /// function call
    Call(PragmaValue),
}

/// `PRAGMA` value
// https://sqlite.org/syntax/pragma-value.html
pub type PragmaValue = Expr; // TODO

/// `PRAGMA` name
///
/// Parsed from and displayed as its `snake_case` spelling via `strum`
/// (e.g. `CacheSize` <-> `cache_size`).
// https://sqlite.org/pragma.html
#[derive(Clone, Debug, PartialEq, Eq, EnumIter, EnumString, strum::Display)]
#[strum(serialize_all = "snake_case")]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PragmaName {
    /// Returns the application ID of the database file.
    ApplicationId,
    /// set the autovacuum mode
    AutoVacuum,
    /// `cache_size` pragma
    CacheSize,
    /// List databases
    DatabaseList,
    /// Encoding - only support utf8
    Encoding,
    /// Run integrity check on the database file
    IntegrityCheck,
    /// `journal_mode` pragma
    JournalMode,
    /// Noop as per SQLite docs
    LegacyFileFormat,
    /// Set or get the maximum number of pages in the database file.
    MaxPageCount,
    /// Return the total number of pages in the database file.
    PageCount,
    /// Return the page size of the database in bytes.
    PageSize,
    /// Returns schema version of the database file.
    SchemaVersion,
    /// returns information about the columns of a table
    TableInfo,
    /// enable capture-changes logic for the connection
    UnstableCaptureDataChangesConn,
    /// Returns the user version of the database file.
    UserVersion,
    /// trigger a checkpoint to run on database(s) if WAL is enabled
    WalCheckpoint,
}

/// `CREATE TRIGGER` time
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum TriggerTime {
    /// `BEFORE`
    Before, // default
    /// `AFTER`
    After,
    /// `INSTEAD OF`
    InsteadOf,
}

/// `CREATE TRIGGER` event
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum TriggerEvent {
    /// `DELETE`
    Delete,
    /// `INSERT`
    Insert,
    /// `UPDATE`
    Update,
    /// `UPDATE OF`: col names
    UpdateOf(DistinctNames),
}

/// `CREATE TRIGGER` command
// https://sqlite.org/lang_createtrigger.html
// https://sqlite.org/syntax/create-trigger-stmt.html
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// NOTE(review): garbled in this view — `Select(Box,` is never closed and is
// followed by `pub` fields that look like the tail of another struct
// (a trigger `INSERT` command?), so the definitions in between appear to be
// missing. Reproduced verbatim; restore from version control before building.
pub enum TriggerCmd {
    /// `UPDATE`
    Update(Box),
    /// `INSERT`
    Insert(Box),
    /// `DELETE`
    Delete(Box),
    /// `SELECT`
    Select(Box,
    /// `ON CONFLICT` clause
    pub upsert: Option,
    /// `RETURNING`
    pub returning: Option>,
}

/// `DELETE` trigger command
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TriggerCmdDelete {
    /// table name
    pub tbl_name: Name,
    /// `WHERE` clause
    pub where_clause: Option,
}

/// Conflict resolution types
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ResolveType {
    /// `ROLLBACK`
    Rollback,
    /// `ABORT`
    Abort, // default
    /// `FAIL`
    Fail,
    /// `IGNORE`
    Ignore,
    /// `REPLACE`
    Replace,
}

impl ResolveType {
    /// Get the OE_XXX bit value: `ROLLBACK` = 1, `ABORT` = 2, `FAIL` = 3,
    /// `IGNORE` = 4, `REPLACE` = 5.
    pub fn bit_value(&self) -> usize {
        match self {
            ResolveType::Rollback => 1,
            ResolveType::Abort => 2,
            ResolveType::Fail => 3,
            ResolveType::Ignore => 4,
            ResolveType::Replace => 5,
        }
    }
}

/// `WITH` clause
// https://sqlite.org/lang_with.html
// https://sqlite.org/syntax/with-clause.html
#[derive(Clone, Debug,
PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct With { /// `RECURSIVE` pub recursive: bool, /// CTEs pub ctes: Vec, } /// CTE materialization #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum Materialized { /// No hint Any, /// `MATERIALIZED` Yes, /// `NOT MATERIALIZED` No, } /// CTE // https://sqlite.org/syntax/common-table-expression.html #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct CommonTableExpr { /// table name pub tbl_name: Name, /// table columns pub columns: Option>, // check no duplicate /// `MATERIALIZED` pub materialized: Materialized, /// query pub select: Box