diff --git a/parser/src/parser.rs b/parser/src/parser.rs index 56bedd1e6..740579f48 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -580,7 +580,7 @@ impl<'a> Parser<'a> { TokenType::TK_DROP, TokenType::TK_INSERT, TokenType::TK_REPLACE, - // add more + TokenType::TK_UPDATE, ])?; match tok.token_type.unwrap() { @@ -601,6 +601,7 @@ impl<'a> Parser<'a> { TokenType::TK_DELETE => self.parse_delete(), TokenType::TK_DROP => self.parse_drop_stmt(), TokenType::TK_INSERT | TokenType::TK_REPLACE => self.parse_insert(), + TokenType::TK_UPDATE => self.parse_update(), _ => unreachable!(), } } @@ -875,7 +876,7 @@ impl<'a> Parser<'a> { TokenType::TK_SELECT | TokenType::TK_VALUES => { Ok(Stmt::Select(self.parse_select_without_cte(with)?)) } - TokenType::TK_UPDATE => todo!(), + TokenType::TK_UPDATE => self.parse_update_without_cte(with), TokenType::TK_DELETE => self.parse_delete_without_cte(with), TokenType::TK_INSERT | TokenType::TK_REPLACE => self.parse_insert_without_cte(with), _ => unreachable!(), @@ -3944,6 +3945,37 @@ impl<'a> Parser<'a> { let with = self.parse_with()?; self.parse_insert_without_cte(with) } + + fn parse_update_without_cte(&mut self, with: Option) -> Result { + self.eat_assert(&[TokenType::TK_UPDATE]); + let resolve_type = self.parse_or_conflict()?; + let tbl_name = self.parse_fullname(true)?; + let indexed = self.parse_indexed()?; + self.eat_expect(&[TokenType::TK_SET])?; + let sets = self.parse_set_list()?; + let from = self.parse_from_clause_opt()?; + let where_clause = self.parse_where()?; + let returning = self.parse_returning()?; + let order_by = self.parse_order_by()?; + let limit = self.parse_limit()?; + Ok(Stmt::Update { + with, + or_conflict: resolve_type, + tbl_name, + indexed, + sets, + from, + where_clause, + returning, + order_by, + limit, + }) + } + + fn parse_update(&mut self) -> Result { + let with = self.parse_with()?; + self.parse_update_without_cte(with) + } } #[cfg(test)] @@ -11112,6 +11144,112 @@ mod tests { ], })], ), + // parse update + ( + b"UPDATE foo SET bar = 1".as_slice(), + vec![Cmd::Stmt(Stmt::Update { + with: None, + or_conflict: None, + tbl_name: QualifiedName { + db_name: None, + name: Name::Ident("foo".to_owned()), + alias: None, + }, + indexed: None, + sets: vec![ + Set { + col_names: vec![ + Name::Ident("bar".to_owned()), + ], + expr: Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + } + ], + from: None, + where_clause: None, + returning: vec![], + order_by: vec![], + limit: None, + })], + ), + ( + b"WITH test AS (SELECT 1) UPDATE OR REPLACE foo NOT INDEXED SET bar = 1 FROM foo_2 WHERE 1 RETURNING bar ORDER By bar LIMIT 1".as_slice(), + vec![Cmd::Stmt(Stmt::Update { + with: Some(With { + recursive: false, + ctes: vec![ + CommonTableExpr { + tbl_name: Name::Ident("test".to_owned()), + columns: vec![], + materialized: Materialized::Any, + select: Select { + with: None, + body: SelectBody { + select: OneSelect::Select { + distinctness: None, + columns: vec![ResultColumn::Expr( + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + None, + )], + from: None, + where_clause: None, + group_by: None, + window_clause: vec![], + }, + compounds: vec![], + }, + order_by: vec![], + limit: None, + }, + } + ], + }), + or_conflict: Some(ResolveType::Replace), + tbl_name: QualifiedName { + db_name: None, + name: Name::Ident("foo".to_owned()), + alias: None, + }, + indexed: Some(Indexed::NotIndexed), + sets: vec![ + Set { + col_names: vec![ + Name::Ident("bar".to_owned()), + ], + expr: Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + } + ], + from: Some(FromClause { + select: Box::new(SelectTable::Table( + QualifiedName { + db_name: None, + name: Name::Ident("foo_2".to_owned()), + alias: None, + }, + None, + None, + )), + joins: vec![] + }), + where_clause: Some(Box::new(Expr::Literal(Literal::Numeric("1".to_owned())))), + returning: vec![ + ResultColumn::Expr( + Box::new(Expr::Id(Name::Ident("bar".to_owned()))), + None, + ), + ], + order_by: vec![ + SortedColumn { + expr: Box::new(Expr::Id(Name::Ident("bar".to_owned()))), + order: None, + nulls: None, + } + ], + limit: Some(Limit { + expr: Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + offset: None, + }), + })], + ), ]; for (input, expected) in test_cases {