vendor sqlite3-parser (lemon-rs)

commit 3cc9d9d79f
parent 8efeb16b82
Author: jussisaurio
Date:   2024-11-16 13:22:26 +02:00

38 changed files with 15973 additions and 37 deletions


@@ -0,0 +1,95 @@
use std::error;
use std::fmt;
use std::io;
use crate::lexer::scan::ScanError;
use crate::parser::ParserError;
/// SQL lexer and parser errors
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
/// I/O Error
Io(io::Error),
/// Lexer error
UnrecognizedToken(Option<(u64, usize)>),
/// Missing quote or double-quote or backtick
UnterminatedLiteral(Option<(u64, usize)>),
/// Missing `]`
UnterminatedBracket(Option<(u64, usize)>),
/// Missing `*/`
UnterminatedBlockComment(Option<(u64, usize)>),
/// Invalid parameter name
BadVariableName(Option<(u64, usize)>),
/// Invalid number format
BadNumber(Option<(u64, usize)>),
/// Invalid or missing sign after `!`
ExpectedEqualsSign(Option<(u64, usize)>),
/// BLOB literals are string literals containing hexadecimal data and preceded by a single "x" or "X" character.
MalformedBlobLiteral(Option<(u64, usize)>),
/// Hexadecimal integer literals follow the C-language notation of "0x" or "0X" followed by hexadecimal digits.
MalformedHexInteger(Option<(u64, usize)>),
/// Grammar error
ParserError(ParserError, Option<(u64, usize)>),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Self::Io(ref err) => err.fmt(f),
Self::UnrecognizedToken(pos) => write!(f, "unrecognized token at {:?}", pos.unwrap()),
Self::UnterminatedLiteral(pos) => {
write!(f, "non-terminated literal at {:?}", pos.unwrap())
}
Self::UnterminatedBracket(pos) => {
write!(f, "non-terminated bracket at {:?}", pos.unwrap())
}
Self::UnterminatedBlockComment(pos) => {
write!(f, "non-terminated block comment at {:?}", pos.unwrap())
}
Self::BadVariableName(pos) => write!(f, "bad variable name at {:?}", pos.unwrap()),
Self::BadNumber(pos) => write!(f, "bad number at {:?}", pos.unwrap()),
Self::ExpectedEqualsSign(pos) => write!(f, "expected = sign at {:?}", pos.unwrap()),
Self::MalformedBlobLiteral(pos) => {
write!(f, "malformed blob literal at {:?}", pos.unwrap())
}
Self::MalformedHexInteger(pos) => {
write!(f, "malformed hex integer at {:?}", pos.unwrap())
}
Self::ParserError(ref msg, Some(pos)) => write!(f, "{msg} at {pos:?}"),
Self::ParserError(ref msg, _) => write!(f, "{msg}"),
}
}
}
impl error::Error for Error {}
impl From<io::Error> for Error {
fn from(err: io::Error) -> Self {
Self::Io(err)
}
}
impl From<ParserError> for Error {
fn from(err: ParserError) -> Self {
Self::ParserError(err, None)
}
}
impl ScanError for Error {
fn position(&mut self, line: u64, column: usize) {
match *self {
Self::Io(_) => {}
Self::UnrecognizedToken(ref mut pos) => *pos = Some((line, column)),
Self::UnterminatedLiteral(ref mut pos) => *pos = Some((line, column)),
Self::UnterminatedBracket(ref mut pos) => *pos = Some((line, column)),
Self::UnterminatedBlockComment(ref mut pos) => *pos = Some((line, column)),
Self::BadVariableName(ref mut pos) => *pos = Some((line, column)),
Self::BadNumber(ref mut pos) => *pos = Some((line, column)),
Self::ExpectedEqualsSign(ref mut pos) => *pos = Some((line, column)),
Self::MalformedBlobLiteral(ref mut pos) => *pos = Some((line, column)),
Self::MalformedHexInteger(ref mut pos) => *pos = Some((line, column)),
Self::ParserError(_, ref mut pos) => *pos = Some((line, column)),
}
}
}
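The `Display` arms above call `pos.unwrap()`, relying on the invariant that the scanner back-fills positions through `ScanError::position` before any lexer error escapes; `Io` is the only variant without a position slot. A minimal sketch of that flow (editor's example, not part of the vendored file; it assumes the crate is consumed under its upstream name `sqlite3_parser` and that `lexer::scan` is public, as in upstream lemon-rs):

```rust
use sqlite3_parser::lexer::scan::ScanError;
use sqlite3_parser::lexer::sql::Error;

fn main() {
    // Lexer errors are created without a position...
    let mut err = Error::UnterminatedBlockComment(None);
    // ...and the scanner back-fills line/column before returning them.
    err.position(3, 14);
    // Prints: non-terminated block comment at (3, 14)
    println!("{err}");
}
```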


@@ -0,0 +1,678 @@
//! Adaptation/port of [`SQLite` tokenizer](http://www.sqlite.org/src/artifact?ci=trunk&filename=src/tokenize.c)
use fallible_iterator::FallibleIterator;
use memchr::memchr;
pub use crate::dialect::TokenType;
use crate::dialect::TokenType::*;
use crate::dialect::{
is_identifier_continue, is_identifier_start, keyword_token, sentinel, MAX_KEYWORD_LEN,
};
use crate::parser::ast::Cmd;
use crate::parser::parse::{yyParser, YYCODETYPE};
use crate::parser::Context;
mod error;
#[cfg(test)]
mod test;
use crate::lexer::scan::ScanError;
use crate::lexer::scan::Splitter;
use crate::lexer::Scanner;
pub use crate::parser::ParserError;
pub use error::Error;
// TODO Extract scanning stuff and move this into the parser crate
// to make it possible to use the tokenizer without depending on the parser...
/// SQL parser
pub struct Parser<'input> {
input: &'input [u8],
scanner: Scanner<Tokenizer>,
parser: yyParser<'input>,
}
impl<'input> Parser<'input> {
/// Constructor
pub fn new(input: &'input [u8]) -> Self {
let lexer = Tokenizer::new();
let scanner = Scanner::new(lexer);
let ctx = Context::new(input);
let parser = yyParser::new(ctx);
Parser {
input,
scanner,
parser,
}
}
/// Parse new `input`
pub fn reset(&mut self, input: &'input [u8]) {
self.input = input;
self.scanner.reset();
}
/// Current line position in input
pub fn line(&self) -> u64 {
self.scanner.line()
}
/// Current column position in input
pub fn column(&self) -> usize {
self.scanner.column()
}
}
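`Parser` is driven through the `FallibleIterator` impl further down: each call to `next` yields one parsed `Cmd`, `Ok(None)` at end of input, or an `Error`. A usage sketch (editor's example, assuming `Cmd` derives `Debug` as it does upstream):

```rust
use fallible_iterator::FallibleIterator;
use sqlite3_parser::lexer::sql::{Error, Parser};

fn main() -> Result<(), Error> {
    let mut parser = Parser::new(b"SELECT 1; CREATE TABLE t (x INTEGER);");
    // next() returns Ok(Some(cmd)) per statement, then Ok(None) when done.
    while let Some(cmd) = parser.next()? {
        println!("{cmd:?}");
    }
    Ok(())
}
```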
/*
** Return the id of the next token in input.
*/
fn get_token(scanner: &mut Scanner<Tokenizer>, input: &[u8]) -> Result<TokenType, Error> {
let mut t = {
let (_, token_type) = match scanner.scan(input)? {
(_, None, _) => {
return Ok(TK_EOF);
}
(_, Some(tuple), _) => tuple,
};
token_type
};
if t == TK_ID
|| t == TK_STRING
|| t == TK_JOIN_KW
|| t == TK_WINDOW
|| t == TK_OVER
|| yyParser::parse_fallback(t as YYCODETYPE) == TK_ID as YYCODETYPE
{
t = TK_ID;
}
Ok(t)
}
/*
** The following three functions are called immediately after the tokenizer
** reads the keywords WINDOW, OVER and FILTER, respectively, to determine
** whether the token should be treated as a keyword or an SQL identifier.
** This cannot be handled by the usual lemon %fallback method, due to
** the ambiguity in some constructions. e.g.
**
** SELECT sum(x) OVER ...
**
** In the above, "OVER" might be a keyword, or it might be an alias for the
** sum(x) expression. If a "%fallback ID OVER" directive were added to
** the grammar, then SQLite would always treat "OVER" as an alias, making it
** impossible to call a window-function without a FILTER clause.
**
** WINDOW is treated as a keyword if:
**
** * the following token is an identifier, or a keyword that can fallback
** to being an identifier, and
** * the token after that one is TK_AS.
**
** OVER is a keyword if:
**
** * the previous token was TK_RP, and
** * the next token is either TK_LP or an identifier.
**
** FILTER is a keyword if:
**
** * the previous token was TK_RP, and
** * the next token is TK_LP.
*/
fn analyze_window_keyword(
scanner: &mut Scanner<Tokenizer>,
input: &[u8],
) -> Result<TokenType, Error> {
let t = get_token(scanner, input)?;
if t != TK_ID {
return Ok(TK_ID);
};
let t = get_token(scanner, input)?;
if t != TK_AS {
return Ok(TK_ID);
};
Ok(TK_WINDOW)
}
fn analyze_over_keyword(
scanner: &mut Scanner<Tokenizer>,
input: &[u8],
last_token: TokenType,
) -> Result<TokenType, Error> {
if last_token == TK_RP {
let t = get_token(scanner, input)?;
if t == TK_LP || t == TK_ID {
return Ok(TK_OVER);
}
}
Ok(TK_ID)
}
fn analyze_filter_keyword(
scanner: &mut Scanner<Tokenizer>,
input: &[u8],
last_token: TokenType,
) -> Result<TokenType, Error> {
if last_token == TK_RP && get_token(scanner, input)? == TK_LP {
return Ok(TK_FILTER);
}
Ok(TK_ID)
}
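To make the disambiguation above concrete: after `sum(x)` the previous token is `TK_RP`, so `OVER` stays a keyword only when the next token is `(` or an identifier; otherwise it falls back to a plain identifier, here a column alias. Both statements below should parse (editor's sketch, not part of this commit):

```rust
use fallible_iterator::FallibleIterator;
use sqlite3_parser::lexer::sql::Parser;

fn main() {
    // previous token TK_RP, next token '(' => OVER is the window keyword
    let mut p = Parser::new(b"SELECT sum(x) OVER (ORDER BY y) FROM t;");
    assert!(p.next().unwrap().is_some());

    // next token is FROM => OVER is demoted to an identifier,
    // i.e. an alias for the sum(x) expression
    let mut p = Parser::new(b"SELECT sum(x) OVER FROM t;");
    assert!(p.next().unwrap().is_some());
}
```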
macro_rules! try_with_position {
($scanner:expr, $expr:expr) => {
match $expr {
Ok(val) => val,
Err(err) => {
let mut err = Error::from(err);
err.position($scanner.line(), $scanner.column());
return Err(err);
}
}
};
}
impl FallibleIterator for Parser<'_> {
type Item = Cmd;
type Error = Error;
fn next(&mut self) -> Result<Option<Cmd>, Error> {
//print!("line: {}, column: {}: ", self.scanner.line(), self.scanner.column());
self.parser.ctx.reset();
let mut last_token_parsed = TK_EOF;
let mut eof = false;
loop {
let (start, (value, mut token_type), end) = match self.scanner.scan(self.input)? {
(_, None, _) => {
eof = true;
break;
}
(start, Some(tuple), end) => (start, tuple, end),
};
let token = if token_type >= TK_WINDOW {
debug_assert!(
token_type == TK_OVER || token_type == TK_FILTER || token_type == TK_WINDOW
);
self.scanner.mark();
if token_type == TK_WINDOW {
token_type = analyze_window_keyword(&mut self.scanner, self.input)?;
} else if token_type == TK_OVER {
token_type =
analyze_over_keyword(&mut self.scanner, self.input, last_token_parsed)?;
} else if token_type == TK_FILTER {
token_type =
analyze_filter_keyword(&mut self.scanner, self.input, last_token_parsed)?;
}
self.scanner.reset_to_mark();
token_type.to_token(start, value, end)
} else {
token_type.to_token(start, value, end)
};
//println!("({:?}, {:?})", token_type, token);
try_with_position!(self.scanner, self.parser.sqlite3Parser(token_type, token));
last_token_parsed = token_type;
if self.parser.ctx.done() {
//println!();
break;
}
}
if last_token_parsed == TK_EOF {
return Ok(None); // empty input
}
/* Upon reaching the end of input, call the parser two more times
with tokens TK_SEMI and TK_EOF, in that order. */
if eof && self.parser.ctx.is_ok() {
if last_token_parsed != TK_SEMI {
try_with_position!(
self.scanner,
self.parser
.sqlite3Parser(TK_SEMI, sentinel(self.input.len()))
);
}
try_with_position!(
self.scanner,
self.parser
.sqlite3Parser(TK_EOF, sentinel(self.input.len()))
);
}
self.parser.sqlite3ParserFinalize();
if let Some(e) = self.parser.ctx.error() {
let err = Error::ParserError(e, Some((self.scanner.line(), self.scanner.column())));
return Err(err);
}
let cmd = self.parser.ctx.cmd();
if let Some(ref cmd) = cmd {
if let Err(e) = cmd.check() {
let err = Error::ParserError(e, Some((self.scanner.line(), self.scanner.column())));
return Err(err);
}
}
Ok(cmd)
}
}
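Note the end-of-input handling above: a trailing statement without `;` still parses, because the driver injects `TK_SEMI` and `TK_EOF` once the scanner runs dry. For instance (editor's sketch):

```rust
use fallible_iterator::FallibleIterator;
use sqlite3_parser::lexer::sql::Parser;

fn main() {
    let mut p = Parser::new(b"SELECT 1"); // no trailing semicolon
    assert!(p.next().unwrap().is_some()); // TK_SEMI/TK_EOF injected at EOF
    assert!(p.next().unwrap().is_none()); // input exhausted thereafter
}
```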
/// SQL token
pub type Token<'input> = (&'input [u8], TokenType);
/// SQL lexer
#[derive(Default)]
pub struct Tokenizer {}
impl Tokenizer {
/// Constructor
pub fn new() -> Self {
Self {}
}
}
/// ```rust
/// use sqlite3_parser::lexer::sql::Tokenizer;
/// use sqlite3_parser::lexer::Scanner;
///
/// let tokenizer = Tokenizer::new();
/// let input = b"PRAGMA parser_trace=ON;";
/// let mut s = Scanner::new(tokenizer);
/// let Ok((_, Some((token1, _)), _)) = s.scan(input) else { panic!() };
/// s.scan(input).unwrap();
/// assert!(b"PRAGMA".eq_ignore_ascii_case(token1));
/// ```
impl Splitter for Tokenizer {
type Error = Error;
type TokenType = TokenType;
fn split<'input>(
&mut self,
data: &'input [u8],
) -> Result<(Option<Token<'input>>, usize), Error> {
if data[0].is_ascii_whitespace() {
// eat as much space as possible
return Ok((
None,
match data.iter().skip(1).position(|&b| !b.is_ascii_whitespace()) {
Some(i) => i + 1,
_ => data.len(),
},
));
}
match data[0] {
b'-' => {
if let Some(b) = data.get(1) {
if *b == b'-' {
// eat comment
if let Some(i) = memchr(b'\n', data) {
Ok((None, i + 1))
} else {
Ok((None, data.len()))
}
} else if *b == b'>' {
if let Some(b) = data.get(2) {
if *b == b'>' {
return Ok((Some((&data[..3], TK_PTR)), 3));
}
}
Ok((Some((&data[..2], TK_PTR)), 2))
} else {
Ok((Some((&data[..1], TK_MINUS)), 1))
}
} else {
Ok((Some((&data[..1], TK_MINUS)), 1))
}
}
b'(' => Ok((Some((&data[..1], TK_LP)), 1)),
b')' => Ok((Some((&data[..1], TK_RP)), 1)),
b';' => Ok((Some((&data[..1], TK_SEMI)), 1)),
b'+' => Ok((Some((&data[..1], TK_PLUS)), 1)),
b'*' => Ok((Some((&data[..1], TK_STAR)), 1)),
b'/' => {
if let Some(b) = data.get(1) {
if *b == b'*' {
// eat comment
let mut pb = 0;
let mut end = None;
for (i, b) in data.iter().enumerate().skip(2) {
if *b == b'/' && pb == b'*' {
end = Some(i);
break;
}
pb = *b;
}
if let Some(i) = end {
Ok((None, i + 1))
} else {
Err(Error::UnterminatedBlockComment(None))
}
} else {
Ok((Some((&data[..1], TK_SLASH)), 1))
}
} else {
Ok((Some((&data[..1], TK_SLASH)), 1))
}
}
b'%' => Ok((Some((&data[..1], TK_REM)), 1)),
b'=' => {
if let Some(b) = data.get(1) {
Ok(if *b == b'=' {
(Some((&data[..2], TK_EQ)), 2)
} else {
(Some((&data[..1], TK_EQ)), 1)
})
} else {
Ok((Some((&data[..1], TK_EQ)), 1))
}
}
b'<' => {
if let Some(b) = data.get(1) {
Ok(match *b {
b'=' => (Some((&data[..2], TK_LE)), 2),
b'>' => (Some((&data[..2], TK_NE)), 2),
b'<' => (Some((&data[..2], TK_LSHIFT)), 2),
_ => (Some((&data[..1], TK_LT)), 1),
})
} else {
Ok((Some((&data[..1], TK_LT)), 1))
}
}
b'>' => {
if let Some(b) = data.get(1) {
Ok(match *b {
b'=' => (Some((&data[..2], TK_GE)), 2),
b'>' => (Some((&data[..2], TK_RSHIFT)), 2),
_ => (Some((&data[..1], TK_GT)), 1),
})
} else {
Ok((Some((&data[..1], TK_GT)), 1))
}
}
b'!' => {
if let Some(b) = data.get(1) {
if *b == b'=' {
Ok((Some((&data[..2], TK_NE)), 2))
} else {
Err(Error::ExpectedEqualsSign(None))
}
} else {
Err(Error::ExpectedEqualsSign(None))
}
}
b'|' => {
if let Some(b) = data.get(1) {
Ok(if *b == b'|' {
(Some((&data[..2], TK_CONCAT)), 2)
} else {
(Some((&data[..1], TK_BITOR)), 1)
})
} else {
Ok((Some((&data[..1], TK_BITOR)), 1))
}
}
b',' => Ok((Some((&data[..1], TK_COMMA)), 1)),
b'&' => Ok((Some((&data[..1], TK_BITAND)), 1)),
b'~' => Ok((Some((&data[..1], TK_BITNOT)), 1)),
quote @ (b'`' | b'\'' | b'"') => literal(data, quote),
b'.' => {
if let Some(b) = data.get(1) {
if b.is_ascii_digit() {
fractional_part(data, 0)
} else {
Ok((Some((&data[..1], TK_DOT)), 1))
}
} else {
Ok((Some((&data[..1], TK_DOT)), 1))
}
}
b'0'..=b'9' => number(data),
b'[' => {
if let Some(i) = memchr(b']', data) {
// Keep the original brackets: '[' ... ']'
Ok((Some((&data[0..=i], TK_ID)), i + 1))
} else {
Err(Error::UnterminatedBracket(None))
}
}
b'?' => {
match data.iter().skip(1).position(|&b| !b.is_ascii_digit()) {
Some(i) => {
// do not include the '?' in the token
Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1))
}
None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())),
}
}
b'$' | b'@' | b'#' | b':' => {
match data
.iter()
.skip(1)
.position(|&b| !is_identifier_continue(b))
{
Some(0) => Err(Error::BadVariableName(None)),
Some(i) => {
// '$' is included as part of the name
Ok((Some((&data[..=i], TK_VARIABLE)), i + 1))
}
None => {
if data.len() == 1 {
return Err(Error::BadVariableName(None));
}
Ok((Some((data, TK_VARIABLE)), data.len()))
}
}
}
b if is_identifier_start(b) => {
if b == b'x' || b == b'X' {
if let Some(&b'\'') = data.get(1) {
blob_literal(data)
} else {
Ok(self.identifierish(data))
}
} else {
Ok(self.identifierish(data))
}
}
_ => Err(Error::UnrecognizedToken(None)),
}
}
}
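`split` above resolves multi-byte operators with one or two bytes of lookahead (`<>`, `<<`, `->`, `->>`, `||`, ...). A quick spot-check of that dispatch, reusing the scanner pattern from the doc comment (editor's sketch):

```rust
use sqlite3_parser::lexer::sql::{Tokenizer, TokenType};
use sqlite3_parser::lexer::Scanner;

fn main() {
    let mut s = Scanner::new(Tokenizer::new());
    let input = b"a <> b ->> c || d";
    // The loop ends when scan() reports end of input with a None token.
    // Expected token types: TK_ID, TK_NE, TK_ID, TK_PTR, TK_ID, TK_CONCAT, TK_ID.
    while let Ok((_, Some((tok, tt)), _)) = s.scan(input) {
        println!("{} -> {:?}", String::from_utf8_lossy(tok), tt);
    }
}
```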
fn literal(data: &[u8], quote: u8) -> Result<(Option<Token<'_>>, usize), Error> {
debug_assert_eq!(data[0], quote);
let tt = if quote == b'\'' { TK_STRING } else { TK_ID };
let mut pb = 0;
let mut end = None;
// data[0] == quote => skip(1)
for (i, b) in data.iter().enumerate().skip(1) {
if *b == quote {
if pb == quote {
// escaped quote
pb = 0;
continue;
}
} else if pb == quote {
end = Some(i);
break;
}
pb = *b;
}
if end.is_some() || pb == quote {
let i = match end {
Some(i) => i,
_ => data.len(),
};
// keep original quotes in the token
Ok((Some((&data[0..i], tt)), i))
} else {
Err(Error::UnterminatedLiteral(None))
}
}
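In `literal`, a doubled quote is an escape, and the returned token keeps the outer quotes (and any doubled quotes) verbatim; dequoting happens later in the parser. For example (editor's sketch):

```rust
use sqlite3_parser::lexer::sql::{Tokenizer, TokenType};
use sqlite3_parser::lexer::Scanner;

fn main() {
    let mut s = Scanner::new(Tokenizer::new());
    let Ok((_, Some((tok, tt)), _)) = s.scan(b"'it''s'") else { panic!() };
    assert_eq!(tok, b"'it''s'");          // quotes kept, '' not unescaped here
    assert_eq!(tt, TokenType::TK_STRING); // single quotes => string, not identifier
}
```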
fn blob_literal(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
debug_assert!(data[0] == b'x' || data[0] == b'X');
debug_assert_eq!(data[1], b'\'');
if let Some((i, b)) = data
.iter()
.enumerate()
.skip(2)
.find(|&(_, &b)| !b.is_ascii_hexdigit())
{
if *b != b'\'' || i % 2 != 0 {
return Err(Error::MalformedBlobLiteral(None));
}
Ok((Some((&data[2..i], TK_BLOB)), i + 1))
} else {
Err(Error::MalformedBlobLiteral(None))
}
}
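`blob_literal` accepts only an even number of hex digits followed by the closing quote, and unlike `literal` it strips the delimiters from the token. A sketch (editor's example):

```rust
use sqlite3_parser::lexer::sql::{Error, Tokenizer};
use sqlite3_parser::lexer::Scanner;

fn main() {
    // even digit count: the token is the hex payload without the x'...' wrapper
    let mut s = Scanner::new(Tokenizer::new());
    let Ok((_, Some((tok, _)), _)) = s.scan(b"x'CAFE'") else { panic!() };
    assert_eq!(tok, b"CAFE");

    // odd digit count is malformed
    let mut s = Scanner::new(Tokenizer::new());
    assert!(matches!(
        s.scan(b"x'CAF'").unwrap_err(),
        Error::MalformedBlobLiteral(_)
    ));
}
```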
fn number(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
debug_assert!(data[0].is_ascii_digit());
if data[0] == b'0' {
if let Some(b) = data.get(1) {
if *b == b'x' || *b == b'X' {
return hex_integer(data);
}
} else {
return Ok((Some((data, TK_INTEGER)), data.len()));
}
}
if let Some((i, b)) = find_end_of_number(data, 1, u8::is_ascii_digit)? {
if b == b'.' {
return fractional_part(data, i);
} else if b == b'e' || b == b'E' {
return exponential_part(data, i);
} else if is_identifier_start(b) {
return Err(Error::BadNumber(None));
}
Ok((Some((&data[..i], TK_INTEGER)), i))
} else {
Ok((Some((data, TK_INTEGER)), data.len()))
}
}
fn hex_integer(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
debug_assert_eq!(data[0], b'0');
debug_assert!(data[1] == b'x' || data[1] == b'X');
if let Some((i, b)) = find_end_of_number(data, 2, u8::is_ascii_hexdigit)? {
// Must not be empty (0x alone is invalid)
if i == 2 || is_identifier_start(b) {
return Err(Error::MalformedHexInteger(None));
}
Ok((Some((&data[..i], TK_INTEGER)), i))
} else {
// Must not be empty (0x alone is invalid)
if data.len() == 2 {
return Err(Error::MalformedHexInteger(None));
}
Ok((Some((data, TK_INTEGER)), data.len()))
}
}
fn fractional_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
debug_assert_eq!(data[i], b'.');
if let Some((i, b)) = find_end_of_number(data, i + 1, u8::is_ascii_digit)? {
if b == b'e' || b == b'E' {
return exponential_part(data, i);
} else if is_identifier_start(b) {
return Err(Error::BadNumber(None));
}
Ok((Some((&data[..i], TK_FLOAT)), i))
} else {
Ok((Some((data, TK_FLOAT)), data.len()))
}
}
fn exponential_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
debug_assert!(data[i] == b'e' || data[i] == b'E');
// data[i] == 'e'|'E'
if let Some(b) = data.get(i + 1) {
let i = if *b == b'+' || *b == b'-' { i + 1 } else { i };
if let Some((j, b)) = find_end_of_number(data, i + 1, u8::is_ascii_digit)? {
if j == i + 1 || is_identifier_start(b) {
return Err(Error::BadNumber(None));
}
Ok((Some((&data[..j], TK_FLOAT)), j))
} else {
if data.len() == i + 1 {
return Err(Error::BadNumber(None));
}
Ok((Some((data, TK_FLOAT)), data.len()))
}
} else {
Err(Error::BadNumber(None))
}
}
fn find_end_of_number(
data: &[u8],
i: usize,
test: fn(&u8) -> bool,
) -> Result<Option<(usize, u8)>, Error> {
for (j, &b) in data.iter().enumerate().skip(i) {
if test(&b) {
continue;
} else if b == b'_' {
if j >= 1 && data.get(j - 1).map_or(false, test) && data.get(j + 1).map_or(false, test)
{
continue;
}
return Err(Error::BadNumber(None));
} else {
return Ok(Some((j, b)));
}
}
Ok(None)
}
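Putting the numeric helpers together: `number` dispatches to `hex_integer`, `fractional_part`, and `exponential_part`, and `find_end_of_number` lets `_` separators through only between two digits. A few spot checks (editor's sketch; `first_token` is a hypothetical helper for the example):

```rust
use sqlite3_parser::lexer::sql::{Error, Tokenizer};
use sqlite3_parser::lexer::Scanner;

// Hypothetical helper: lex the first token of `input`.
fn first_token(input: &[u8]) -> Result<&[u8], Error> {
    let mut s = Scanner::new(Tokenizer::new());
    match s.scan(input)? {
        (_, Some((tok, _)), _) => Ok(tok),
        _ => Ok(b""),
    }
}

fn main() -> Result<(), Error> {
    assert_eq!(first_token(b"0x1A ")?, b"0x1A");   // hex_integer
    assert_eq!(first_token(b"1_000 ")?, b"1_000"); // '_' only between digits
    assert_eq!(first_token(b".5 ")?, b".5");       // fractional_part from '.'
    assert_eq!(first_token(b"1e-3 ")?, b"1e-3");   // exponential_part with sign
    assert!(first_token(b"1__0 ").is_err());       // consecutive '_' => BadNumber
    Ok(())
}
```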
impl Tokenizer {
fn identifierish<'input>(&mut self, data: &'input [u8]) -> (Option<Token<'input>>, usize) {
debug_assert!(is_identifier_start(data[0]));
// data[0] is_identifier_start => skip(1)
let end = data
.iter()
.skip(1)
.position(|&b| !is_identifier_continue(b));
let i = match end {
Some(i) => i + 1,
_ => data.len(),
};
let word = &data[..i];
let tt = if word.len() >= 2 && word.len() <= MAX_KEYWORD_LEN && word.is_ascii() {
keyword_token(word).unwrap_or(TK_ID)
} else {
TK_ID
};
(Some((word, tt)), i)
}
}
#[cfg(test)]
mod tests {
use super::Tokenizer;
use crate::dialect::TokenType;
use crate::lexer::sql::Error;
use crate::lexer::Scanner;
#[test]
fn fallible_iterator() -> Result<(), Error> {
let tokenizer = Tokenizer::new();
let input = b"PRAGMA parser_trace=ON;";
let mut s = Scanner::new(tokenizer);
expect_token(&mut s, input, b"PRAGMA", TokenType::TK_PRAGMA)?;
expect_token(&mut s, input, b"parser_trace", TokenType::TK_ID)?;
Ok(())
}
#[test]
fn invalid_number_literal() -> Result<(), Error> {
let tokenizer = Tokenizer::new();
let input = b"SELECT 1E;";
let mut s = Scanner::new(tokenizer);
expect_token(&mut s, input, b"SELECT", TokenType::TK_SELECT)?;
let err = s.scan(input).unwrap_err();
assert!(matches!(err, Error::BadNumber(_)));
Ok(())
}
fn expect_token(
s: &mut Scanner<Tokenizer>,
input: &[u8],
token: &[u8],
token_type: TokenType,
) -> Result<(), Error> {
let (t, tt) = s.scan(input)?.1.unwrap();
assert_eq!(token, t);
assert_eq!(token_type, tt);
Ok(())
}
}


@@ -0,0 +1,376 @@
use fallible_iterator::FallibleIterator;
use super::{Error, Parser};
use crate::parser::ast::fmt::ToTokens;
use crate::parser::{
ast::{Cmd, Name, ParameterInfo, QualifiedName, Stmt},
ParserError,
};
#[test]
fn count_placeholders() {
let ast = parse_cmd(b"SELECT ? WHERE 1 = ?");
let mut info = ParameterInfo::default();
ast.to_tokens(&mut info).unwrap();
assert_eq!(info.count, 2);
}
#[test]
fn count_numbered_placeholders() {
let ast = parse_cmd(b"SELECT ?1 WHERE 1 = ?2 AND 0 = ?1");
let mut info = ParameterInfo::default();
ast.to_tokens(&mut info).unwrap();
assert_eq!(info.count, 2);
}
#[test]
fn count_unused_placeholders() {
let ast = parse_cmd(b"SELECT ?1 WHERE 1 = ?3");
let mut info = ParameterInfo::default();
ast.to_tokens(&mut info).unwrap();
assert_eq!(info.count, 3);
}
#[test]
fn count_named_placeholders() {
let ast = parse_cmd(b"SELECT :x, :y WHERE 1 = :y");
let mut info = ParameterInfo::default();
ast.to_tokens(&mut info).unwrap();
assert_eq!(info.count, 2);
assert_eq!(info.names.len(), 2);
assert!(info.names.contains(":x"));
assert!(info.names.contains(":y"));
}
#[test]
fn duplicate_column() {
expect_parser_err_msg(
b"CREATE TABLE t (x TEXT, x TEXT)",
"duplicate column name: x",
);
expect_parser_err_msg(
b"CREATE TABLE t (x TEXT, \"x\" TEXT)",
"duplicate column name: \"x\"",
);
expect_parser_err_msg(
b"CREATE TABLE t (x TEXT, `x` TEXT)",
"duplicate column name: `x`",
);
}
#[test]
fn create_table_without_column() {
expect_parser_err(
b"CREATE TABLE t ()",
ParserError::SyntaxError(")".to_owned()),
);
}
#[test]
fn vtab_args() -> Result<(), Error> {
let sql = b"CREATE VIRTUAL TABLE mail USING fts3(
subject VARCHAR(256) NOT NULL,
body TEXT CHECK(length(body)<10240)
);";
let r = parse_cmd(sql);
let Cmd::Stmt(Stmt::CreateVirtualTable {
tbl_name: QualifiedName {
name: Name(tbl_name),
..
},
module_name: Name(module_name),
args: Some(args),
..
}) = r
else {
panic!("unexpected AST")
};
assert_eq!(tbl_name, "mail");
assert_eq!(module_name, "fts3");
assert_eq!(args.len(), 2);
assert_eq!(args[0], "subject VARCHAR(256) NOT NULL");
assert_eq!(args[1], "body TEXT CHECK(length(body)<10240)");
Ok(())
}
#[test]
fn only_semicolons_no_statements() {
let sqls = ["", ";", ";;;"];
for sql in &sqls {
let r = parse(sql.as_bytes());
assert_eq!(r.unwrap(), None);
}
}
#[test]
fn extra_semicolons_between_statements() {
let sqls = [
"SELECT 1; SELECT 2",
"SELECT 1; SELECT 2;",
"; SELECT 1; SELECT 2",
";; SELECT 1;; SELECT 2;;",
];
for sql in &sqls {
let mut parser = Parser::new(sql.as_bytes());
assert!(matches!(
parser.next().unwrap(),
Some(Cmd::Stmt(Stmt::Select { .. }))
));
assert!(matches!(
parser.next().unwrap(),
Some(Cmd::Stmt(Stmt::Select { .. }))
));
assert_eq!(parser.next().unwrap(), None);
}
}
#[test]
fn extra_comments_between_statements() {
let sqls = [
"-- abc\nSELECT 1; --def\nSELECT 2 -- ghj",
"/* abc */ SELECT 1; /* def */ SELECT 2; /* ghj */",
"/* abc */; SELECT 1 /* def */; SELECT 2 /* ghj */",
"/* abc */;; SELECT 1;/* def */; SELECT 2; /* ghj */; /* klm */",
];
for sql in &sqls {
let mut parser = Parser::new(sql.as_bytes());
assert!(matches!(
parser.next().unwrap(),
Some(Cmd::Stmt(Stmt::Select { .. }))
));
assert!(matches!(
parser.next().unwrap(),
Some(Cmd::Stmt(Stmt::Select { .. }))
));
assert_eq!(parser.next().unwrap(), None);
}
}
#[test]
fn insert_mismatch_count() {
expect_parser_err_msg(b"INSERT INTO t (a, b) VALUES (1)", "1 values for 2 columns");
}
#[test]
fn insert_default_values() {
expect_parser_err_msg(
b"INSERT INTO t (a) DEFAULT VALUES",
"0 values for 1 columns",
);
}
#[test]
fn create_view_mismatch_count() {
expect_parser_err_msg(
b"CREATE VIEW v (c1, c2) AS SELECT 1",
"expected 2 columns for v but got 1",
);
}
#[test]
fn create_view_duplicate_column_name() {
expect_parser_err_msg(
b"CREATE VIEW v (c1, c1) AS SELECT 1, 2",
"duplicate column name: c1",
);
}
#[test]
fn create_table_without_rowid_missing_pk() {
expect_parser_err_msg(
b"CREATE TABLE t (c1) WITHOUT ROWID",
"PRIMARY KEY missing on table t",
);
}
#[test]
fn create_temporary_table_with_qualified_name() {
expect_parser_err_msg(
b"CREATE TEMPORARY TABLE mem.x AS SELECT 1",
"temporary table name must be unqualified",
);
parse_cmd(b"CREATE TEMPORARY TABLE temp.x AS SELECT 1");
}
#[test]
fn create_table_with_only_generated_column() {
expect_parser_err_msg(
b"CREATE TABLE test(data AS (1))",
"must have at least one non-generated column",
);
}
#[test]
fn create_strict_table_missing_datatype() {
expect_parser_err_msg(b"CREATE TABLE t (c1) STRICT", "missing datatype for t.c1");
}
#[test]
fn create_strict_table_unknown_datatype() {
expect_parser_err_msg(
b"CREATE TABLE t (c1 BOOL) STRICT",
"unknown datatype for t.c1: \"BOOL\"",
);
}
#[test]
fn foreign_key_on_column() {
expect_parser_err_msg(
b"CREATE TABLE t(a REFERENCES o(a,b))",
"foreign key on a should reference only one column of table o",
);
}
#[test]
fn create_strict_table_generated_column() {
parse_cmd(
b"CREATE TABLE IF NOT EXISTS transactions (
debit REAL,
credit REAL,
amount REAL GENERATED ALWAYS AS (ifnull(credit, 0.0) -ifnull(debit, 0.0))
) STRICT;",
);
}
#[test]
fn selects_compound_mismatch_columns_count() {
expect_parser_err_msg(
b"SELECT 1 UNION SELECT 1, 2",
"SELECTs to the left and right of UNION do not have the same number of result columns",
);
}
#[test]
fn delete_order_by_without_limit() {
expect_parser_err_msg(
b"DELETE FROM t ORDER BY x",
"ORDER BY without LIMIT on DELETE",
);
}
#[test]
fn update_order_by_without_limit() {
expect_parser_err_msg(
b"UPDATE t SET x = 1 ORDER BY x",
"ORDER BY without LIMIT on UPDATE",
);
}
#[test]
fn values_mismatch_columns_count() {
expect_parser_err_msg(
b"INSERT INTO t VALUES (1), (1,2)",
"all VALUES must have the same number of terms",
);
}
#[test]
fn column_specified_more_than_once() {
expect_parser_err_msg(
b"INSERT INTO t (n, n, m) VALUES (1, 0, 2)",
"column \"n\" specified more than once",
)
}
#[test]
fn alter_add_column_primary_key() {
expect_parser_err_msg(
b"ALTER TABLE t ADD COLUMN c PRIMARY KEY",
"Cannot add a PRIMARY KEY column",
);
}
#[test]
fn alter_add_column_unique() {
expect_parser_err_msg(
b"ALTER TABLE t ADD COLUMN c UNIQUE",
"Cannot add a UNIQUE column",
);
}
#[test]
fn alter_rename_same() {
expect_parser_err_msg(
b"ALTER TABLE t RENAME TO t",
"there is already another table or index with this name: t",
);
}
#[test]
fn natural_join_on() {
expect_parser_err_msg(
b"SELECT x FROM t NATURAL JOIN t USING (x)",
"a NATURAL join may not have an ON or USING clause",
);
expect_parser_err_msg(
b"SELECT x FROM t NATURAL JOIN t ON t.x = t.x",
"a NATURAL join may not have an ON or USING clause",
);
}
#[test]
fn missing_join_clause() {
expect_parser_err_msg(
b"SELECT a FROM tt ON b",
"a JOIN clause is required before ON",
);
}
#[test]
fn cast_without_typename() {
parse_cmd(b"SELECT CAST(a AS ) FROM t");
}
#[test]
fn unknown_table_option() {
expect_parser_err_msg(b"CREATE TABLE t(x)o", "unknown table option: o");
expect_parser_err_msg(b"CREATE TABLE t(x) WITHOUT o", "unknown table option: o");
}
#[test]
fn qualified_table_name_within_triggers() {
expect_parser_err_msg(
b"CREATE TRIGGER tr1 AFTER INSERT ON t1 BEGIN
DELETE FROM main.t2;
END;",
"qualified table names are not allowed on INSERT, UPDATE, and DELETE statements \
within triggers",
);
}
#[test]
fn indexed_by_clause_within_triggers() {
expect_parser_err_msg(
b"CREATE TRIGGER main.t16err5 AFTER INSERT ON tA BEGIN
UPDATE t16 INDEXED BY t16a SET rowid=rowid+1 WHERE a=1;
END;",
"the INDEXED BY clause is not allowed on UPDATE or DELETE statements \
within triggers",
);
expect_parser_err_msg(
b"CREATE TRIGGER main.t16err6 AFTER INSERT ON tA BEGIN
DELETE FROM t16 NOT INDEXED WHERE a=123;
END;",
"the NOT INDEXED clause is not allowed on UPDATE or DELETE statements \
within triggers",
);
}
fn expect_parser_err_msg(input: &[u8], error_msg: &str) {
expect_parser_err(input, ParserError::Custom(error_msg.to_owned()))
}
fn expect_parser_err(input: &[u8], err: ParserError) {
let r = parse(input);
if let Error::ParserError(e, _) = r.unwrap_err() {
assert_eq!(e, err);
} else {
panic!("unexpected error type")
};
}
fn parse_cmd(input: &[u8]) -> Cmd {
parse(input).unwrap().unwrap()
}
fn parse(input: &[u8]) -> Result<Option<Cmd>, Error> {
let mut parser = Parser::new(input);
parser.next()
}
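For completeness, the position plumbing from the first file is visible at this level too: a failed parse surfaces as `Error::ParserError` carrying the scanner's line and column. A hedged sketch (editor's example):

```rust
use fallible_iterator::FallibleIterator;
use sqlite3_parser::lexer::sql::{Error, Parser};

fn main() {
    let mut p = Parser::new(b"SELECT 1 FROM"); // truncated statement
    match p.next() {
        Err(Error::ParserError(e, Some((line, col)))) => {
            eprintln!("{e} at line {line}, column {col}")
        }
        other => panic!("expected a syntax error, got {other:?}"),
    }
}
```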