diff --git a/core/translate/display.rs b/core/translate/display.rs index 36eed99eb..9a0bd5d6d 100644 --- a/core/translate/display.rs +++ b/core/translate/display.rs @@ -613,7 +613,7 @@ impl ToTokens for UpdatePlan { .unwrap(); ast::Set { - col_names: vec![ast::Name::new(col_name)], + col_names: vec![ast::Name::exact(col_name.clone())], expr: set_expr.clone(), } }), diff --git a/core/translate/index.rs b/core/translate/index.rs index 4273f00e8..0fe25b882 100644 --- a/core/translate/index.rs +++ b/core/translate/index.rs @@ -392,10 +392,7 @@ fn create_idx_stmt_to_sql( sql.push_str(", "); } let col_ident = match col.expr.as_ref() { - Expr::Id(ast::Name::Ident(col_name)) - | Expr::Id(ast::Name::Quoted(col_name)) - | Expr::Name(ast::Name::Ident(col_name)) - | Expr::Name(ast::Name::Quoted(col_name)) => col_name, + Expr::Id(name) | Expr::Name(name) => name.as_str(), _ => unreachable!("expressions in CREATE INDEX should have been rejected earlier"), }; sql.push_str(col_ident); diff --git a/core/translate/planner.rs b/core/translate/planner.rs index e1919afbc..d0c06d2a7 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -968,7 +968,7 @@ fn parse_join( { for left_col in left_table.columns().iter().filter(|col| !col.hidden) { if left_col.name == right_col.name { - distinct_names.push(ast::Name::new( + distinct_names.push(ast::Name::exact( left_col.name.clone().expect("column name is None"), )); found_match = true; diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 5264a5725..26b98c0bb 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -291,7 +291,7 @@ fn update_pragma( program = translate_create_table( QualifiedName { db_name: None, - name: ast::Name::new(table), + name: ast::Name::exact(table.to_string()), alias: None, }, false, @@ -796,7 +796,7 @@ pub const TURSO_CDC_DEFAULT_TABLE_NAME: &str = "turso_cdc"; fn turso_cdc_table_columns() -> Vec { vec![ ast::ColumnDefinition { - col_name: ast::Name::new("change_id"), + col_name: ast::Name::exact("change_id".to_string()), col_type: Some(ast::Type { name: "INTEGER".to_string(), size: None, @@ -811,7 +811,7 @@ fn turso_cdc_table_columns() -> Vec { }], }, ast::ColumnDefinition { - col_name: ast::Name::new("change_time"), + col_name: ast::Name::exact("change_time".to_string()), col_type: Some(ast::Type { name: "INTEGER".to_string(), size: None, @@ -819,7 +819,7 @@ fn turso_cdc_table_columns() -> Vec { constraints: vec![], }, ast::ColumnDefinition { - col_name: ast::Name::new("change_type"), + col_name: ast::Name::exact("change_type".to_string()), col_type: Some(ast::Type { name: "INTEGER".to_string(), size: None, @@ -827,7 +827,7 @@ fn turso_cdc_table_columns() -> Vec { constraints: vec![], }, ast::ColumnDefinition { - col_name: ast::Name::new("table_name"), + col_name: ast::Name::exact("table_name".to_string()), col_type: Some(ast::Type { name: "TEXT".to_string(), size: None, @@ -835,12 +835,12 @@ fn turso_cdc_table_columns() -> Vec { constraints: vec![], }, ast::ColumnDefinition { - col_name: ast::Name::new("id"), + col_name: ast::Name::exact("id".to_string()), col_type: None, constraints: vec![], }, ast::ColumnDefinition { - col_name: ast::Name::new("before"), + col_name: ast::Name::exact("before".to_string()), col_type: Some(ast::Type { name: "BLOB".to_string(), size: None, @@ -848,7 +848,7 @@ fn turso_cdc_table_columns() -> Vec { constraints: vec![], }, ast::ColumnDefinition { - col_name: ast::Name::new("after"), + col_name: ast::Name::exact("after".to_string()), col_type: Some(ast::Type { name: "BLOB".to_string(), size: None, @@ -856,7 +856,7 @@ fn turso_cdc_table_columns() -> Vec { constraints: vec![], }, ast::ColumnDefinition { - col_name: ast::Name::new("updates"), + col_name: ast::Name::exact("updates".to_string()), col_type: Some(ast::Type { name: "BLOB".to_string(), size: None, diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index ea714b6da..0240c656e 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -5222,7 +5222,7 @@ pub fn op_function( Some( ast::Stmt::CreateIndex { - tbl_name: ast::Name::new(original_rename_to), + tbl_name: ast::Name::exact(original_rename_to.to_string()), unique, if_not_exists, idx_name, @@ -5248,7 +5248,7 @@ pub fn op_function( ast::Stmt::CreateTable { tbl_name: ast::QualifiedName { db_name: None, - name: ast::Name::new(original_rename_to), + name: ast::Name::exact(original_rename_to.to_string()), alias: None, }, temporary, @@ -5365,7 +5365,7 @@ pub fn op_function( let column = columns .iter_mut() .find(|column| { - column.col_name == ast::Name::new(original_rename_from) + column.col_name == ast::Name::exact(original_rename_from.to_string()) }) .expect("column being renamed should be present"); diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 912f1baf8..adfabe00b 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -1,6 +1,8 @@ pub mod check; pub mod fmt; +use std::sync::OnceLock; + use strum_macros::{EnumIter, EnumString}; /// `?` or `$` Prepared statement arg placeholder(s) @@ -879,22 +881,62 @@ pub struct GroupBy { /// identifier or string or `CROSS` or `FULL` or `INNER` or `LEFT` or `NATURAL` or `OUTER` or `RIGHT`. #[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub enum Name { - /// Identifier - Ident(String), - /// Quoted values - Quoted(String), +pub struct Name { + quote: Option, + value: String, + lowercase: OnceLock, + value_is_lowercase: bool, +} + +#[cfg(feature = "serde")] +impl serde::Serialize for Name { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.value) + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for Name { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct NameVisitor; + impl<'de> serde::de::Visitor<'de> for NameVisitor { + type Value = Name; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(Name::from_bytes(v.as_bytes())) + } + } + deserializer.deserialize_str(NameVisitor) + } } impl Name { pub fn exact(s: String) -> Self { - Self::Ident(s) + let value_is_lowercase = s.chars().all(|x| x.is_lowercase()); + Self { + value: s, + quote: None, + lowercase: OnceLock::new(), + value_is_lowercase, + } } pub fn from_bytes(s: &[u8]) -> Self { - Self::new(unsafe { std::str::from_utf8_unchecked(s) }.to_owned()) + Self::from_str(unsafe { std::str::from_utf8_unchecked(s) }) } - pub fn new(s: impl AsRef) -> Self { + pub fn from_str(s: impl AsRef) -> Self { let s = s.as_ref(); let bytes = s.as_bytes(); @@ -902,22 +944,51 @@ impl Name { return Name::exact(s.to_string()); } - match bytes[0] { - b'"' | b'\'' | b'`' | b'[' => Name::Quoted(s.to_string()), - _ => Name::exact(s.to_string()), + if matches!(bytes[0], b'"' | b'\'' | b'`') { + assert!(s.len() >= 2); + assert!(bytes[bytes.len() - 1] == bytes[0]); + let s = match bytes[0] { + b'"' => s[1..s.len() - 1].replace("\"\"", "\""), + b'\'' => s[1..s.len() - 1].replace("''", "'"), + b'`' => s[1..s.len() - 1].replace("``", "`"), + _ => unreachable!(), + }; + let value_is_lowercase = s.chars().all(|x| x.is_lowercase()); + Name { + value: s, + quote: Some(bytes[0] as char), + lowercase: OnceLock::new(), + value_is_lowercase, + } + } else if bytes[0] == b'[' { + assert!(s.len() >= 2); + assert!(bytes[bytes.len() - 1] == b']'); + Name::exact(s[1..s.len() - 1].to_string()) + } else { + Name::exact(s.to_string()) } } pub fn as_str(&self) -> &str { - match self { - Name::Ident(s) => s.as_str(), - Name::Quoted(s) => &s[1..s.len() - 1], + if self.value_is_lowercase { + return &self.value; } + if self.lowercase.get().is_none() { + let _ = self.lowercase.set(self.value.to_lowercase()); + } + self.lowercase.get().unwrap() } pub fn as_literal(&self) -> &str { - match self { - Name::Ident(s) | Name::Quoted(s) => s.as_str(), + &self.value + } + + pub fn as_quoted(&self) -> String { + let value = self.value.as_bytes(); + if !value.is_empty() && value.iter().all(|x| x.is_ascii_alphanumeric()) { + self.value.clone() + } else { + format!("\"{}\"", self.value.replace("\"", "\"\"")) } } @@ -927,14 +998,11 @@ impl Name { /// /// Also, used to detect string literals in PRAGMA cases pub fn quoted_with(&self, quote: char) -> bool { - if let Self::Quoted(ident) = self { - return ident.starts_with(quote); - } - false + self.quote == Some(quote) } pub fn quoted(&self) -> bool { - matches!(self, Self::Quoted(..)) + self.quote.is_some() } } diff --git a/parser/src/ast/fmt.rs b/parser/src/ast/fmt.rs index c7a8db510..7fdf19a9c 100644 --- a/parser/src/ast/fmt.rs +++ b/parser/src/ast/fmt.rs @@ -745,7 +745,7 @@ impl ToTokens for Expr { Self::Collate(expr, collation) => { expr.to_tokens(s, context)?; s.append(TK_COLLATE, None)?; - double_quote(collation.as_str(), s) + s.append(TK_ID, Some(&collation.as_quoted())) } Self::DoublyQualified(db_name, tbl_name, col_name) => { db_name.to_tokens(s, context)?; @@ -1370,7 +1370,7 @@ impl ToTokens for Name { s: &mut S, _: &C, ) -> Result<(), S::Error> { - double_quote(self.as_literal(), s) + s.append(TK_ID, Some(&self.as_quoted())) } } @@ -2460,11 +2460,3 @@ where { s.comma(items, context) } - -// TK_ID: [...] / `...` / "..." / some keywords / non keywords -fn double_quote(name: &str, s: &mut S) -> Result<(), S::Error> { - if name.is_empty() { - return s.append(TK_ID, Some("\"\"")); - } - s.append(TK_ID, Some(name)) -} diff --git a/parser/src/parser.rs b/parser/src/parser.rs index d9bfc593f..f55abf48b 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -4127,7 +4127,7 @@ mod tests { b"BEGIN EXCLUSIVE TRANSACTION 'my_transaction'".as_slice(), vec![Cmd::Stmt(Stmt::Begin { typ: Some(TransactionType::Exclusive), - name: Some(Name::new("'my_transaction'".to_string())), + name: Some(Name::from_str("'my_transaction'".to_string())), })], ), ( @@ -4148,7 +4148,7 @@ mod tests { b"BEGIN CONCURRENT TRANSACTION 'my_transaction'".as_slice(), vec![Cmd::Stmt(Stmt::Begin { typ: Some(TransactionType::Concurrent), - name: Some(Name::new("'my_transaction'".to_string())), + name: Some(Name::from_str("'my_transaction'".to_string())), })], ), ( @@ -4243,7 +4243,7 @@ mod tests { ( b"SAVEPOINT 'my_savepoint'".as_slice(), vec![Cmd::Stmt(Stmt::Savepoint { - name: Name::new("'my_savepoint'".to_string()), + name: Name::from_str("'my_savepoint'".to_string()), })], ), // release @@ -4262,7 +4262,7 @@ mod tests { ( b"RELEASE SAVEPOINT 'my_savepoint'".as_slice(), vec![Cmd::Stmt(Stmt::Release { - name: Name::new("'my_savepoint'".to_string()), + name: Name::from_str("'my_savepoint'".to_string()), })], ), ( @@ -11474,13 +11474,13 @@ mod tests { if_not_exists: false, tbl_name: QualifiedName { db_name: None, - name: Name::new("\"settings\"".to_owned()), + name: Name::from_str("\"settings\"".to_owned()), alias: None, }, body: CreateTableBody::ColumnsAndConstraints{ columns: vec![ ColumnDefinition { - col_name: Name::new("\"enabled\"".to_owned()), + col_name: Name::from_str("\"enabled\"".to_owned()), col_type: Some(Type { name: "INTEGER".to_owned(), size: None, diff --git a/sql_generation/generation/expr.rs b/sql_generation/generation/expr.rs index 244bf6469..982ec3255 100644 --- a/sql_generation/generation/expr.rs +++ b/sql_generation/generation/expr.rs @@ -236,7 +236,7 @@ impl Arbitrary for QualifiedName { // TODO: for now forego alias Self { db_name: None, - name: Name::new(&table.name), + name: Name::from_str(&table.name), alias: None, } } diff --git a/sql_generation/generation/predicate/binary.rs b/sql_generation/generation/predicate/binary.rs index 2b7df9b08..5a82e52d7 100644 --- a/sql_generation/generation/predicate/binary.rs +++ b/sql_generation/generation/predicate/binary.rs @@ -88,8 +88,8 @@ impl Predicate { Box::new(|_| { Some(Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), ast::Operator::Equals, Box::new(Expr::Literal(value.into())), @@ -105,8 +105,8 @@ impl Predicate { } else { Some(Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), ast::Operator::NotEquals, Box::new(Expr::Literal(v.into())), @@ -120,8 +120,8 @@ impl Predicate { let lt_value = LTValue::arbitrary_from(rng, context, value).0; Some(Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), ast::Operator::Greater, Box::new(Expr::Literal(lt_value.into())), @@ -134,8 +134,8 @@ impl Predicate { let gt_value = GTValue::arbitrary_from(rng, context, value).0; Some(Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), ast::Operator::Less, Box::new(Expr::Literal(gt_value.into())), @@ -149,8 +149,8 @@ impl Predicate { LikeValue::arbitrary_from_maybe(rng, context, value).map(|like| { Expr::Like { lhs: Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), not: false, // TODO: also generate this value eventually op: ast::LikeOperator::Like, @@ -199,8 +199,8 @@ impl Predicate { Box::new(|_| { Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), ast::Operator::NotEquals, Box::new(Expr::Literal(value.into())), @@ -215,8 +215,8 @@ impl Predicate { }; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), ast::Operator::Equals, Box::new(Expr::Literal(v.into())), @@ -226,8 +226,8 @@ impl Predicate { let gt_value = GTValue::arbitrary_from(rng, context, value).0; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), ast::Operator::Greater, Box::new(Expr::Literal(gt_value.into())), @@ -237,8 +237,8 @@ impl Predicate { let lt_value = LTValue::arbitrary_from(rng, context, value).0; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(&table_name), - ast::Name::new(&column.name), + ast::Name::from_str(&table_name), + ast::Name::from_str(&column.name), )), ast::Operator::Less, Box::new(Expr::Literal(lt_value.into())), @@ -275,8 +275,8 @@ impl SimplePredicate { Box::new(|_rng| { Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(table_name), - ast::Name::new(&column.column.name), + ast::Name::from_str(table_name), + ast::Name::from_str(&column.column.name), )), ast::Operator::Equals, Box::new(Expr::Literal(column_value.into())), @@ -286,8 +286,8 @@ impl SimplePredicate { let lt_value = LTValue::arbitrary_from(rng, context, column_value).0; Expr::Binary( Box::new(Expr::Qualified( - ast::Name::new(table_name), - ast::Name::new(&column.column.name), + ast::Name::from_str(table_name), + ast::Name::from_str(&column.column.name), )), ast::Operator::Greater, Box::new(Expr::Literal(lt_value.into())), @@ -297,8 +297,8 @@ impl SimplePredicate { let gt_value = GTValue::arbitrary_from(rng, context, column_value).0; Expr::Binary( Box::new(Expr::Qualified( - ast::Name::new(table_name), - ast::Name::new(&column.column.name), + ast::Name::from_str(table_name), + ast::Name::from_str(&column.column.name), )), ast::Operator::Less, Box::new(Expr::Literal(gt_value.into())), @@ -333,8 +333,8 @@ impl SimplePredicate { Box::new(|_rng| { Expr::Binary( Box::new(Expr::Qualified( - ast::Name::new(table_name), - ast::Name::new(&column.column.name), + ast::Name::from_str(table_name), + ast::Name::from_str(&column.column.name), )), ast::Operator::NotEquals, Box::new(Expr::Literal(column_value.into())), @@ -344,8 +344,8 @@ impl SimplePredicate { let gt_value = GTValue::arbitrary_from(rng, context, column_value).0; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(table_name), - ast::Name::new(&column.column.name), + ast::Name::from_str(table_name), + ast::Name::from_str(&column.column.name), )), ast::Operator::Greater, Box::new(Expr::Literal(gt_value.into())), @@ -355,8 +355,8 @@ impl SimplePredicate { let lt_value = LTValue::arbitrary_from(rng, context, column_value).0; Expr::Binary( Box::new(ast::Expr::Qualified( - ast::Name::new(table_name), - ast::Name::new(&column.column.name), + ast::Name::from_str(table_name), + ast::Name::from_str(&column.column.name), )), ast::Operator::Less, Box::new(Expr::Literal(lt_value.into())), diff --git a/sql_generation/model/query/select.rs b/sql_generation/model/query/select.rs index 721df334a..ddf6fc1f7 100644 --- a/sql_generation/model/query/select.rs +++ b/sql_generation/model/query/select.rs @@ -187,7 +187,7 @@ impl FromClause { fn to_sql_ast(&self) -> ast::FromClause { ast::FromClause { select: Box::new(ast::SelectTable::Table( - ast::QualifiedName::single(ast::Name::new(&self.table)), + ast::QualifiedName::single(ast::Name::from_str(&self.table)), None, None, )), @@ -203,7 +203,7 @@ impl FromClause { JoinType::Cross => ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS)), }, table: Box::new(ast::SelectTable::Table( - ast::QualifiedName::single(ast::Name::new(&join.table)), + ast::QualifiedName::single(ast::Name::from_str(&join.table)), None, None, )),