complete parser integration

This commit is contained in:
Levy A.
2025-08-21 15:03:28 -03:00
parent c6b032de63
commit 4ba1304fb9
36 changed files with 1047 additions and 1283 deletions

View File

@@ -16,7 +16,7 @@ pub enum LimboError {
ParseError(String),
#[error(transparent)]
#[diagnostic(transparent)]
LexerError(#[from] turso_sqlite3_parser::lexer::sql::Error),
LexerError(#[from] turso_parser::error::Error),
#[error("Conversion error: {0}")]
ConversionError(String),
#[error("Env variable error: {0}")]

View File

@@ -323,8 +323,11 @@ impl FilterPredicate {
pub fn from_select(select: &turso_parser::ast::Select) -> crate::Result<Self> {
use turso_parser::ast::*;
if let OneSelect::Select(select_stmt) = &*select.body.select {
if let Some(where_clause) = &select_stmt.where_clause {
if let OneSelect::Select {
ref where_clause, ..
} = select.body.select
{
if let Some(where_clause) = where_clause {
Self::from_sql_expr(where_clause)
} else {
Ok(FilterPredicate::None)
@@ -344,7 +347,7 @@ pub enum ProjectColumn {
Column(String),
/// Computed expression
Expression {
expr: turso_parser::ast::Expr,
expr: Box<turso_parser::ast::Expr>,
alias: Option<String>,
},
}
@@ -643,11 +646,7 @@ impl ProjectOperator {
output
}
fn evaluate_expression(
&self,
expr: &turso_parser::ast::Expr,
values: &[Value],
) -> Value {
fn evaluate_expression(&self, expr: &turso_parser::ast::Expr, values: &[Value]) -> Value {
use turso_parser::ast::*;
match expr {
@@ -749,15 +748,11 @@ impl ProjectOperator {
Expr::FunctionCall { name, args, .. } => {
match name.as_str().to_lowercase().as_str() {
"hex" => {
if let Some(arg_list) = args {
if arg_list.len() == 1 {
let arg_val = self.evaluate_expression(&arg_list[0], values);
match arg_val {
Value::Integer(i) => Value::Text(Text::new(&format!("{i:X}"))),
_ => Value::Null,
}
} else {
Value::Null
if args.len() == 1 {
let arg_val = self.evaluate_expression(&args[0], values);
match arg_val {
Value::Integer(i) => Value::Text(Text::new(&format!("{i:X}"))),
_ => Value::Null,
}
} else {
Value::Null

View File

@@ -7,10 +7,10 @@ use crate::schema::{BTreeTable, Column, Schema};
use crate::types::{IOCompletions, IOResult, Value};
use crate::util::{extract_column_name_from_expr, extract_view_columns};
use crate::{io_yield_one, Completion, LimboError, Result, Statement};
use fallible_iterator::FallibleIterator;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::{Arc, Mutex};
use turso_parser::ast;
use turso_parser::{
ast::{Cmd, Stmt},
parser::Parser,
@@ -73,7 +73,7 @@ pub struct IncrementalView {
// WHERE clause predicate for filtering (kept for compatibility)
pub where_predicate: FilterPredicate,
// The SELECT statement that defines how to transform input data
pub select_stmt: Box<turso_parser::ast::Select>,
pub select_stmt: Box<ast::Select>,
// Internal filter operator for predicate evaluation
filter_operator: Option<FilterOperator>,
@@ -96,10 +96,7 @@ pub struct IncrementalView {
impl IncrementalView {
/// Validate that a CREATE MATERIALIZED VIEW statement can be handled by IncrementalView
/// This should be called early, before updating sqlite_master
pub fn can_create_view(
select: &turso_parser::ast::Select,
schema: &Schema,
) -> Result<()> {
pub fn can_create_view(select: &ast::Select, schema: &Schema) -> Result<()> {
// Check for aggregations
let (group_by_columns, aggregate_functions, _) = Self::extract_aggregation_info(select);
@@ -150,7 +147,7 @@ impl IncrementalView {
pub fn has_same_sql(&self, sql: &str) -> bool {
// Parse the SQL to extract just the SELECT statement
if let Ok(Some(Cmd::Stmt(Stmt::CreateMaterializedView { select, .. }))) =
Parser::new(sql.as_bytes()).next()
Parser::new(sql.as_bytes()).next_cmd()
{
// Compare the SELECT statements as SQL strings
use turso_parser::ast::fmt::ToTokens;
@@ -175,7 +172,7 @@ impl IncrementalView {
}
pub fn from_sql(sql: &str, schema: &Schema) -> Result<Self> {
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
let cmd = parser.next_cmd()?;
let cmd = cmd.expect("View is an empty statement");
match cmd {
Cmd::Stmt(Stmt::CreateMaterializedView {
@@ -183,7 +180,7 @@ impl IncrementalView {
view_name,
columns: _,
select,
}) => IncrementalView::from_stmt(view_name, select, schema),
}) => IncrementalView::from_stmt(view_name, select.into(), schema),
_ => Err(LimboError::ParseError(format!(
"View is not a CREATE MATERIALIZED VIEW statement: {sql}"
))),
@@ -191,8 +188,8 @@ impl IncrementalView {
}
pub fn from_stmt(
view_name: turso_parser::ast::QualifiedName,
select: Box<turso_parser::ast::Select>,
view_name: ast::QualifiedName,
select: Box<ast::Select>,
schema: &Schema,
) -> Result<Self> {
let name = view_name.name.as_str().to_string();
@@ -253,7 +250,7 @@ impl IncrementalView {
name: String,
initial_data: Vec<(i64, Vec<Value>)>,
where_predicate: FilterPredicate,
select_stmt: Box<turso_parser::ast::Select>,
select_stmt: Box<ast::Select>,
base_table: Arc<BTreeTable>,
base_table_column_names: Vec<String>,
columns: Vec<Column>,
@@ -353,19 +350,20 @@ impl IncrementalView {
/// Validate that view columns are a strict subset of the base table columns
/// No duplicates, no complex expressions, only simple column references
fn validate_view_columns(
select: &turso_parser::ast::Select,
select: &ast::Select,
base_table_column_names: &[String],
) -> Result<()> {
if let turso_parser::ast::OneSelect::Select(ref select_stmt) = &*select.body.select
{
if let ast::OneSelect::Select { ref columns, .. } = select.body.select {
let mut seen_columns = std::collections::HashSet::new();
for result_col in &select_stmt.columns {
for result_col in columns {
match result_col {
turso_parser::ast::ResultColumn::Expr(
turso_parser::ast::Expr::Id(name),
_,
) => {
ast::ResultColumn::Expr(expr, _)
if matches!(expr.as_ref(), ast::Expr::Id(_)) =>
{
let ast::Expr::Id(name) = expr.as_ref() else {
unreachable!()
};
let col_name = name.as_str();
// Check for duplicates
@@ -382,7 +380,7 @@ impl IncrementalView {
)));
}
}
turso_parser::ast::ResultColumn::Star => {
ast::ResultColumn::Star => {
// SELECT * is allowed - it's the full set
}
_ => {
@@ -396,17 +394,14 @@ impl IncrementalView {
}
/// Extract the base table name from a SELECT statement (for non-join cases)
fn extract_base_table(select: &turso_parser::ast::Select) -> Option<String> {
if let turso_parser::ast::OneSelect::Select(ref select_stmt) = &*select.body.select
fn extract_base_table(select: &ast::Select) -> Option<String> {
if let ast::OneSelect::Select {
from: Some(ref from),
..
} = select.body.select
{
if let Some(ref from) = &select_stmt.from {
if let Some(ref select_table) = &from.select {
if let turso_parser::ast::SelectTable::Table(name, _, _) =
&**select_table
{
return Some(name.name.as_str().to_string());
}
}
if let ast::SelectTable::Table(name, _, _) = from.select.as_ref() {
return Some(name.name.as_str().to_string());
}
}
None
@@ -625,7 +620,7 @@ impl IncrementalView {
/// Extract GROUP BY columns and aggregate functions from SELECT statement
fn extract_aggregation_info(
select: &turso_parser::ast::Select,
select: &ast::Select,
) -> (Vec<String>, Vec<AggregateFunction>, Vec<String>) {
use turso_parser::ast::*;
@@ -633,9 +628,14 @@ impl IncrementalView {
let mut aggregate_functions = Vec::new();
let mut output_column_names = Vec::new();
if let OneSelect::Select(ref select_stmt) = &*select.body.select {
if let OneSelect::Select {
ref group_by,
ref columns,
..
} = select.body.select
{
// Extract GROUP BY columns
if let Some(ref group_by) = select_stmt.group_by {
if let Some(group_by) = group_by {
for expr in &group_by.exprs {
if let Some(col_name) = extract_column_name_from_expr(expr) {
group_by_columns.push(col_name);
@@ -644,7 +644,7 @@ impl IncrementalView {
}
// Extract aggregate functions and column names/aliases from SELECT list
for result_col in &select_stmt.columns {
for result_col in columns {
match result_col {
ResultColumn::Expr(expr, alias) => {
// Extract aggregate functions
@@ -685,7 +685,7 @@ impl IncrementalView {
/// Recursively extract aggregate functions from an expression
fn extract_aggregates_from_expr(
expr: &turso_parser::ast::Expr,
expr: &ast::Expr,
aggregate_functions: &mut Vec<AggregateFunction>,
) {
use crate::function::Func;
@@ -705,14 +705,12 @@ impl IncrementalView {
}
Expr::FunctionCall { name, args, .. } => {
// Regular function calls with arguments
let arg_count = args.as_ref().map_or(0, |a| a.len());
let arg_count = args.len();
if let Ok(func) = Func::resolve_function(name.as_str(), arg_count) {
// Extract the input column if there's an argument
let input_column = if arg_count > 0 {
args.as_ref()
.and_then(|args| args.first())
.and_then(extract_column_name_from_expr)
args.first().and_then(extract_column_name_from_expr)
} else {
None
};
@@ -737,53 +735,42 @@ impl IncrementalView {
/// Extract JOIN information from SELECT statement
#[allow(clippy::type_complexity)]
pub fn extract_join_info(
select: &turso_parser::ast::Select,
select: &ast::Select,
) -> (Option<(String, String)>, Option<(String, String)>) {
use turso_parser::ast::*;
if let OneSelect::Select(ref select_stmt) = &*select.body.select {
if let Some(ref from) = &select_stmt.from {
// Check if there are any joins
if let Some(ref joins) = &from.joins {
if !joins.is_empty() {
// Get the first (left) table name
let left_table = if let Some(ref select_table) = &from.select {
match &**select_table {
SelectTable::Table(name, _, _) => {
Some(name.name.as_str().to_string())
}
_ => None,
}
} else {
None
};
if let OneSelect::Select {
from: Some(ref from),
..
} = select.body.select
{
// Check if there are any joins
if !from.joins.is_empty() {
// Get the first (left) table name
let left_table = match from.select.as_ref() {
SelectTable::Table(name, _, _) => Some(name.name.as_str().to_string()),
_ => None,
};
// Get the first join (right) table and condition
if let Some(first_join) = joins.first() {
let right_table = match &first_join.table {
SelectTable::Table(name, _, _) => {
Some(name.name.as_str().to_string())
}
_ => None,
};
// Get the first join (right) table and condition
if let Some(first_join) = from.joins.first() {
let right_table = match &first_join.table.as_ref() {
SelectTable::Table(name, _, _) => Some(name.name.as_str().to_string()),
_ => None,
};
// Extract join condition (simplified - assumes single equality)
let join_condition =
if let Some(ref constraint) = &first_join.constraint {
match constraint {
JoinConstraint::On(expr) => {
Self::extract_join_columns_from_expr(expr)
}
_ => None,
}
} else {
None
};
if let (Some(left), Some(right)) = (left_table, right_table) {
return (Some((left, right)), join_condition);
}
// Extract join condition (simplified - assumes single equality)
let join_condition = if let Some(ref constraint) = &first_join.constraint {
match constraint {
JoinConstraint::On(expr) => Self::extract_join_columns_from_expr(expr),
_ => None,
}
} else {
None
};
if let (Some(left), Some(right)) = (left_table, right_table) {
return (Some((left, right)), join_condition);
}
}
}
@@ -793,9 +780,7 @@ impl IncrementalView {
}
/// Extract join column names from a join condition expression
fn extract_join_columns_from_expr(
expr: &turso_parser::ast::Expr,
) -> Option<(String, String)> {
fn extract_join_columns_from_expr(expr: &ast::Expr) -> Option<(String, String)> {
use turso_parser::ast::*;
// Look for expressions like: t1.col = t2.col
@@ -825,18 +810,22 @@ impl IncrementalView {
/// Extract projection columns from SELECT statement
fn extract_project_columns(
select: &turso_parser::ast::Select,
select: &ast::Select,
column_names: &[String],
) -> Option<Vec<ProjectColumn>> {
use turso_parser::ast::*;
if let OneSelect::Select(ref select_stmt) = &*select.body.select {
if let OneSelect::Select {
columns: ref select_columns,
..
} = select.body.select
{
let mut columns = Vec::new();
for result_col in &select_stmt.columns {
for result_col in select_columns {
match result_col {
ResultColumn::Expr(expr, alias) => {
match expr {
match expr.as_ref() {
Expr::Id(name) => {
// Simple column reference
columns.push(ProjectColumn::Column(name.as_str().to_string()));

View File

@@ -51,7 +51,6 @@ use crate::vdbe::metrics::ConnectionMetrics;
use crate::vtab::VirtualTable;
use core::str;
pub use error::{CompletionError, LimboError};
use fallible_iterator::FallibleIterator;
pub use io::clock::{Clock, Instant};
#[cfg(all(feature = "fs", target_family = "unix"))]
pub use io::UnixIO;
@@ -875,7 +874,7 @@ impl Connection {
let sql = sql.as_ref();
tracing::trace!("Preparing: {}", sql);
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
let cmd = parser.next_cmd()?;
let syms = self.syms.borrow();
let cmd = cmd.expect("Successful parse on nonempty input string should produce a command");
let byte_offset_end = parser.offset();
@@ -1032,7 +1031,7 @@ impl Connection {
let sql = sql.as_ref();
tracing::trace!("Preparing and executing batch: {}", sql);
let mut parser = Parser::new(sql.as_bytes());
while let Some(cmd) = parser.next()? {
while let Some(cmd) = parser.next_cmd()? {
let syms = self.syms.borrow();
let pager = self.pager.borrow().clone();
let byte_offset_end = parser.offset();
@@ -1068,7 +1067,7 @@ impl Connection {
let sql = sql.as_ref();
tracing::trace!("Querying: {}", sql);
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
let cmd = parser.next_cmd()?;
let byte_offset_end = parser.offset();
let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end])
.unwrap()
@@ -1110,7 +1109,7 @@ impl Connection {
ast::Stmt::Select(select) => {
let mut plan = prepare_select_plan(
self.schema.borrow().deref(),
*select,
select,
&syms,
&[],
&mut table_ref_counter,
@@ -1140,7 +1139,7 @@ impl Connection {
}
let sql = sql.as_ref();
let mut parser = Parser::new(sql.as_bytes());
while let Some(cmd) = parser.next()? {
while let Some(cmd) = parser.next_cmd()? {
let syms = self.syms.borrow();
let pager = self.pager.borrow().clone();
let byte_offset_end = parser.offset();
@@ -2017,7 +2016,7 @@ impl Statement {
*conn.schema.borrow_mut() = conn._db.clone_schema()?;
self.program = {
let mut parser = Parser::new(self.program.sql.as_bytes());
let cmd = parser.next()?;
let cmd = parser.next_cmd()?;
let cmd = cmd.expect("Same SQL string should be able to be parsed");
let syms = conn.syms.borrow();
@@ -2084,7 +2083,7 @@ impl Statement {
pub fn get_column_type(&self, idx: usize) -> Option<String> {
let column = &self.program.result_columns.get(idx).expect("No column");
match &column.expr {
turso_sqlite3_parser::ast::Expr::Column {
turso_parser::ast::Expr::Column {
table,
column: column_idx,
..
@@ -2227,7 +2226,7 @@ impl Iterator for QueryRunner<'_> {
type Item = Result<Option<Statement>>;
fn next(&mut self) -> Option<Self::Item> {
match self.parser.next() {
match self.parser.next_cmd() {
Ok(Some(cmd)) => {
let byte_offset_end = self.parser.offset();
let input = str::from_utf8(&self.statements[self.last_offset..byte_offset_end])
@@ -2237,10 +2236,7 @@ impl Iterator for QueryRunner<'_> {
Some(self.conn.run_cmd(cmd, input))
}
Ok(None) => None,
Err(err) => {
self.parser.finalize();
Some(Result::Err(LimboError::from(err)))
}
Err(err) => Some(Result::Err(LimboError::from(err))),
}
}
}

View File

@@ -10,7 +10,7 @@ pub struct View {
pub name: String,
pub sql: String,
pub select_stmt: ast::Select,
pub columns: Option<Vec<Column>>,
pub columns: Vec<Column>,
}
/// Type alias for regular views collection
@@ -24,7 +24,6 @@ use crate::util::{module_args_from_sql, module_name_from_sql, IOExt, UnparsedFro
use crate::{return_if_io, LimboError, MvCursor, Pager, RefValue, SymbolTable, VirtualTable};
use crate::{util::normalize_ident, Result};
use core::fmt;
use fallible_iterator::FallibleIterator;
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::{BTreeSet, HashMap};
@@ -409,7 +408,7 @@ impl Schema {
// Parse the SQL to determine if it's a regular or materialized view
let mut parser = Parser::new(sql.as_bytes());
if let Ok(Some(Cmd::Stmt(stmt))) = parser.next() {
if let Ok(Some(Cmd::Stmt(stmt))) = parser.next_cmd() {
match stmt {
Stmt::CreateMaterializedView { .. } => {
// Create IncrementalView for materialized views
@@ -434,11 +433,9 @@ impl Schema {
// If column names were provided in CREATE VIEW (col1, col2, ...),
// use them to rename the columns
let mut final_columns = view_columns;
if let Some(ref names) = column_names {
for (i, indexed_col) in names.iter().enumerate() {
if let Some(col) = final_columns.get_mut(i) {
col.name = Some(indexed_col.col_name.to_string());
}
for (i, indexed_col) in column_names.iter().enumerate() {
if let Some(col) = final_columns.get_mut(i) {
col.name = Some(indexed_col.col_name.to_string());
}
}
@@ -446,8 +443,8 @@ impl Schema {
let view = View {
name: name.to_string(),
sql: sql.to_string(),
select_stmt: *select,
columns: Some(final_columns),
select_stmt: select,
columns: final_columns,
};
self.add_view(view);
}
@@ -696,10 +693,10 @@ impl BTreeTable {
pub fn from_sql(sql: &str, root_page: usize) -> Result<BTreeTable> {
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
let cmd = parser.next_cmd()?;
match cmd {
Some(Cmd::Stmt(Stmt::CreateTable { tbl_name, body, .. })) => {
create_table(tbl_name, *body, root_page)
create_table(tbl_name, body, root_page)
}
_ => unreachable!("Expected CREATE TABLE statement"),
}
@@ -833,53 +830,53 @@ fn create_table(
options,
} => {
is_strict = options.contains(TableOptions::STRICT);
if let Some(constraints) = constraints {
for c in constraints {
if let turso_sqlite3_parser::ast::TableConstraint::PrimaryKey {
columns, ..
} = c.constraint
{
for column in columns {
let col_name = match column.expr {
for c in constraints {
if let ast::TableConstraint::PrimaryKey { columns, .. } = c.constraint {
for column in columns {
let col_name = match column.expr.as_ref() {
Expr::Id(id) => normalize_ident(id.as_str()),
Expr::Literal(Literal::String(value)) => {
value.trim_matches('\'').to_owned()
}
_ => {
todo!("Unsupported primary key expression");
}
};
primary_key_columns
.push((col_name, column.order.unwrap_or(SortOrder::Asc)));
}
} else if let ast::TableConstraint::Unique {
columns,
conflict_clause,
} = c.constraint
{
if conflict_clause.is_some() {
unimplemented!("ON CONFLICT not implemented");
}
let unique_set = columns
.into_iter()
.map(|column| {
let column_name = match column.expr.as_ref() {
Expr::Id(id) => normalize_ident(id.as_str()),
Expr::Literal(Literal::String(value)) => {
value.trim_matches('\'').to_owned()
}
_ => {
todo!("Unsupported primary key expression");
todo!("Unsupported unique expression");
}
};
primary_key_columns
.push((col_name, column.order.unwrap_or(SortOrder::Asc)));
}
} else if let turso_sqlite3_parser::ast::TableConstraint::Unique {
columns,
conflict_clause,
} = c.constraint
{
if conflict_clause.is_some() {
unimplemented!("ON CONFLICT not implemented");
}
let unique_set = columns
.into_iter()
.map(|column| {
let column_name = match column.expr {
Expr::Id(id) => normalize_ident(id.as_str()),
_ => {
todo!("Unsupported unique expression");
}
};
UniqueColumnProps {
column_name,
order: column.order.unwrap_or(SortOrder::Asc),
}
})
.collect();
unique_sets.push(unique_set);
}
UniqueColumnProps {
column_name,
order: column.order.unwrap_or(SortOrder::Asc),
}
})
.collect();
unique_sets.push(unique_set);
}
}
for (col_name, col_def) in columns {
for ast::ColumnDefinition {
col_name,
col_type,
constraints,
} in &columns
{
let name = col_name.as_str().to_string();
// Regular sqlite tables have an integer rowid that uniquely identifies a row.
// Even if you create a table with a column e.g. 'id INT PRIMARY KEY', there will still
@@ -889,17 +886,17 @@ fn create_table(
// A column defined as exactly INTEGER PRIMARY KEY is a rowid alias, meaning that the rowid
// and the value of this column are the same.
// https://www.sqlite.org/lang_createtable.html#rowids_and_the_integer_primary_key
let ty_str = col_def
.col_type
let ty_str = col_type
.as_ref()
.cloned()
.map(|ast::Type { name, .. }| name.clone())
.unwrap_or_default();
let mut typename_exactly_integer = false;
let ty = match col_def.col_type {
let ty = match col_type {
Some(data_type) => 'ty: {
// https://www.sqlite.org/datatype3.html
let mut type_name = data_type.name;
let mut type_name = data_type.name.clone();
type_name.make_ascii_uppercase();
if type_name.is_empty() {
@@ -938,33 +935,28 @@ fn create_table(
let mut order = SortOrder::Asc;
let mut unique = false;
let mut collation = None;
for c_def in col_def.constraints {
for c_def in constraints {
match c_def.constraint {
turso_sqlite3_parser::ast::ColumnConstraint::PrimaryKey {
order: o,
..
} => {
ast::ColumnConstraint::PrimaryKey { order: o, .. } => {
primary_key = true;
if let Some(o) = o {
order = o;
}
}
turso_sqlite3_parser::ast::ColumnConstraint::NotNull {
ast::ColumnConstraint::NotNull {
nullable, ..
} => {
notnull = !nullable;
}
turso_sqlite3_parser::ast::ColumnConstraint::Default(expr) => {
default = Some(expr)
}
ast::ColumnConstraint::Default(ref expr) => default = Some(expr),
// TODO: for now we don't check Resolve type of unique
turso_sqlite3_parser::ast::ColumnConstraint::Unique(on_conflict) => {
ast::ColumnConstraint::Unique(on_conflict) => {
if on_conflict.is_some() {
unimplemented!("ON CONFLICT not implemented");
}
unique = true;
}
turso_sqlite3_parser::ast::ColumnConstraint::Collate { collation_name } => {
ast::ColumnConstraint::Collate { ref collation_name } => {
collation = Some(CollationSeq::new(collation_name.as_str())?);
}
_ => {}
@@ -987,7 +979,7 @@ fn create_table(
primary_key,
is_rowid_alias: typename_exactly_integer && primary_key,
notnull,
default,
default: default.cloned(),
unique,
collation,
hidden: false,
@@ -1059,7 +1051,7 @@ pub struct Column {
pub primary_key: bool,
pub is_rowid_alias: bool,
pub notnull: bool,
pub default: Option<Expr>,
pub default: Option<Box<Expr>>,
pub unique: bool,
pub collation: Option<CollationSeq>,
pub hidden: bool,
@@ -1441,13 +1433,13 @@ pub struct IndexColumn {
/// b.pos_in_table == 1
pub pos_in_table: usize,
pub collation: Option<CollationSeq>,
pub default: Option<Expr>,
pub default: Option<Box<Expr>>,
}
impl Index {
pub fn from_sql(sql: &str, root_page: usize, table: &BTreeTable) -> Result<Index> {
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
let cmd = parser.next_cmd()?;
match cmd {
Some(Cmd::Stmt(Stmt::CreateIndex {
idx_name,

View File

@@ -7114,7 +7114,7 @@ mod tests {
};
use sorted_vec::SortedVec;
use test_log::test;
use turso_sqlite3_parser::ast::SortOrder;
use turso_parser::ast::SortOrder;
use super::*;
use crate::{

View File

@@ -68,7 +68,7 @@ fn emit_collseq_if_needed(
) {
// Check if this is a column expression with explicit COLLATE clause
if let ast::Expr::Collate(_, collation_name) = expr {
if let Ok(collation) = CollationSeq::new(collation_name) {
if let Ok(collation) = CollationSeq::new(collation_name.as_str()) {
program.emit_insn(Insn::CollSeq {
reg: None,
collation,
@@ -189,8 +189,8 @@ pub fn translate_aggregation_step(
if agg.args.len() == 2 {
match &agg.args[1] {
ast::Expr::Column { .. } => {
delimiter_expr = agg.args[1].clone();
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
ast::Expr::Literal(ast::Literal::String(s)) => {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string()));
@@ -309,7 +309,7 @@ pub fn translate_aggregation_step(
let expr = &agg.args[0];
let delimiter_expr = match &agg.args[1] {
ast::Expr::Column { .. } => agg.args[1].clone(),
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
}

View File

@@ -1,4 +1,3 @@
use fallible_iterator::FallibleIterator as _;
use std::sync::Arc;
use turso_parser::{ast, parser::Parser};
@@ -16,7 +15,7 @@ use crate::{
use super::{schema::SQLITE_TABLEID, update::translate_update_for_schema_change};
pub fn translate_alter_table(
alter: (ast::QualifiedName, ast::AlterTableBody),
alter: ast::AlterTable,
syms: &SymbolTable,
schema: &Schema,
mut program: ProgramBuilder,
@@ -24,7 +23,10 @@ pub fn translate_alter_table(
input: &str,
) -> Result<ProgramBuilder> {
program.begin_write_operation();
let (table_name, alter_table) = alter;
let ast::AlterTable {
name: table_name,
body: alter_table,
} = alter;
let table_name = table_name.name.as_str();
if schema.table_has_indexes(table_name) && !schema.indexes_enabled() {
// Let's disable altering a table with indices altogether instead of checking column by
@@ -91,7 +93,8 @@ pub fn translate_alter_table(
);
let mut parser = Parser::new(stmt.as_bytes());
let Some(ast::Cmd::Stmt(ast::Stmt::Update(mut update))) = parser.next().unwrap() else {
let Some(ast::Cmd::Stmt(ast::Stmt::Update(mut update))) = parser.next_cmd().unwrap()
else {
unreachable!();
};
@@ -167,7 +170,7 @@ pub fn translate_alter_table(
if let Some(default) = &column.default {
if !matches!(
default,
default.as_ref(),
ast::Expr::Literal(
ast::Literal::Null
| ast::Literal::Blob(_)
@@ -204,7 +207,8 @@ pub fn translate_alter_table(
);
let mut parser = Parser::new(stmt.as_bytes());
let Some(ast::Cmd::Stmt(ast::Stmt::Update(mut update))) = parser.next().unwrap() else {
let Some(ast::Cmd::Stmt(ast::Stmt::Update(mut update))) = parser.next_cmd().unwrap()
else {
unreachable!();
};

View File

@@ -16,8 +16,8 @@ pub fn translate_delete(
schema: &Schema,
tbl_name: &QualifiedName,
where_clause: Option<Box<Expr>>,
limit: Option<Box<Limit>>,
returning: Option<Vec<ResultColumn>>,
limit: Option<Limit>,
returning: Vec<ResultColumn>,
syms: &SymbolTable,
mut program: ProgramBuilder,
connection: &Arc<crate::Connection>,
@@ -35,7 +35,7 @@ pub fn translate_delete(
// the result set, and only after that it opens the table for writing and deletes the rows. It
// also uses a couple of instructions that we don't implement yet (i.e.: RowSetAdd, RowSetRead,
// RowSetTest). So for now I'll just defer it altogether.
if returning.is_some() {
if !returning.is_empty() {
crate::bail_parse_error!("RETURNING currently not implemented for DELETE statements.");
}
let result_columns = vec![];
@@ -67,7 +67,7 @@ pub fn prepare_delete_plan(
schema: &Schema,
tbl_name: String,
where_clause: Option<Box<Expr>>,
limit: Option<Box<Limit>>,
limit: Option<Limit>,
result_columns: Vec<super::plan::ResultSetColumn>,
table_ref_counter: &mut TableRefIdCounter,
connection: &Arc<crate::Connection>,
@@ -99,7 +99,7 @@ pub fn prepare_delete_plan(
// Parse the WHERE clause
parse_where(
where_clause.map(|e| *e),
where_clause.as_deref(),
&mut table_references,
None,
&mut where_predicates,
@@ -113,7 +113,7 @@ pub fn prepare_delete_plan(
table_references,
result_columns,
where_clause: where_predicates,
order_by: None,
order_by: vec![],
limit: resolved_limit,
offset: resolved_offset,
contains_constant_false_condition: false,

View File

@@ -202,9 +202,9 @@ impl fmt::Display for UpdatePlan {
},
}
}
if let Some(order_by) = &self.order_by {
if !self.order_by.is_empty() {
writeln!(f, "ORDER BY:")?;
for (expr, dir) in order_by {
for (expr, dir) in &self.order_by {
writeln!(
f,
" - {} {}",
@@ -301,7 +301,7 @@ impl ToTokens for Plan {
s.comma(
order_by.iter().map(|(expr, order)| ast::SortedColumn {
expr: expr.clone(),
expr: expr.clone().into(),
order: Some(*order),
nulls: None,
}),
@@ -368,7 +368,13 @@ impl ToTokens for SelectPlan {
context: &C,
) -> Result<(), S::Error> {
if !self.values.is_empty() {
ast::OneSelect::Values(self.values.clone()).to_tokens_with_context(s, context)?;
ast::OneSelect::Values(
self.values
.iter()
.map(|values| values.iter().map(|v| Box::from(v.clone())).collect())
.collect(),
)
.to_tokens_with_context(s, context)?;
} else {
s.append(TokenType::TK_SELECT, None)?;
if self.distinctness.is_distinct() {
@@ -436,12 +442,12 @@ impl ToTokens for SelectPlan {
}
}
if let Some(order_by) = &self.order_by {
if !self.order_by.is_empty() {
s.append(TokenType::TK_ORDER, None)?;
s.append(TokenType::TK_BY, None)?;
s.comma(
order_by.iter().map(|(expr, order)| ast::SortedColumn {
self.order_by.iter().map(|(expr, order)| ast::SortedColumn {
expr: expr.clone(),
order: Some(*order),
nulls: None,
@@ -498,12 +504,12 @@ impl ToTokens for DeletePlan {
}
}
if let Some(order_by) = &self.order_by {
if !self.order_by.is_empty() {
s.append(TokenType::TK_ORDER, None)?;
s.append(TokenType::TK_BY, None)?;
s.comma(
order_by.iter().map(|(expr, order)| ast::SortedColumn {
self.order_by.iter().map(|(expr, order)| ast::SortedColumn {
expr: expr.clone(),
order: Some(*order),
nulls: None,
@@ -556,7 +562,7 @@ impl ToTokens for UpdatePlan {
.unwrap();
ast::Set {
col_names: ast::Names::single(ast::Name::from_str(col_name)),
col_names: vec![ast::Name::new(col_name)],
expr: set_expr.clone(),
}
}),
@@ -579,12 +585,12 @@ impl ToTokens for UpdatePlan {
}
}
if let Some(order_by) = &self.order_by {
if !self.order_by.is_empty() {
s.append(TokenType::TK_ORDER, None)?;
s.append(TokenType::TK_BY, None)?;
s.comma(
order_by.iter().map(|(expr, order)| ast::SortedColumn {
self.order_by.iter().map(|(expr, order)| ast::SortedColumn {
expr: expr.clone(),
order: Some(*order),
nulls: None,

View File

@@ -287,12 +287,12 @@ pub fn emit_query<'a>(
}
// Initialize cursors and other resources needed for query execution
if let Some(ref mut order_by) = plan.order_by {
if !plan.order_by.is_empty() {
init_order_by(
program,
t_ctx,
&plan.result_columns,
order_by,
&plan.order_by,
&plan.table_references,
)?;
}
@@ -359,8 +359,9 @@ pub fn emit_query<'a>(
program.preassign_label_to_next_insn(after_main_loop_label);
let mut order_by_necessary = plan.order_by.is_some() && !plan.contains_constant_false_condition;
let order_by = plan.order_by.as_ref();
let mut order_by_necessary =
!plan.order_by.is_empty() && !plan.contains_constant_false_condition;
let order_by = &plan.order_by;
// Handle GROUP BY and aggregation processing
if plan.group_by.is_some() {
@@ -381,7 +382,7 @@ pub fn emit_query<'a>(
}
// Process ORDER BY results if needed
if order_by.is_some() && order_by_necessary {
if !order_by.is_empty() && order_by_necessary {
emit_order_by(program, t_ctx, plan)?;
}

View File

@@ -60,7 +60,8 @@ macro_rules! expect_arguments_exact {
$expected_arguments:expr,
$func:ident
) => {{
let args = if let Some(args) = $args {
let args = $args;
let args = if !args.is_empty() {
if args.len() != $expected_arguments {
crate::bail_parse_error!(
"{} function called with not exactly {} arguments",
@@ -83,7 +84,8 @@ macro_rules! expect_arguments_max {
$expected_arguments:expr,
$func:ident
) => {{
let args = if let Some(args) = $args {
let args = $args;
let args = if !args.is_empty() {
if args.len() > $expected_arguments {
crate::bail_parse_error!(
"{} function called with more than {} arguments",
@@ -106,7 +108,8 @@ macro_rules! expect_arguments_min {
$expected_arguments:expr,
$func:ident
) => {{
let args = if let Some(args) = $args {
let args = $args;
let args = if !args.is_empty() {
if args.len() < $expected_arguments {
crate::bail_parse_error!(
"{} function with less than {} arguments",
@@ -128,7 +131,7 @@ macro_rules! expect_arguments_even {
$args:expr,
$func:ident
) => {{
let args = $args.as_deref().unwrap_or_default();
let args = $args;
if args.len() % 2 != 0 {
crate::bail_parse_error!(
"{} function requires an even number of arguments",
@@ -151,7 +154,7 @@ fn translate_in_list(
program: &mut ProgramBuilder,
referenced_tables: Option<&TableReferences>,
lhs: &ast::Expr,
rhs: &Option<Vec<ast::Expr>>,
rhs: &[Box<ast::Expr>],
not: bool,
condition_metadata: ConditionMetadata,
resolver: &Resolver,
@@ -171,7 +174,7 @@ fn translate_in_list(
// which is what SQLite also does for small lists of values.
// TODO: Let's refactor this later to use a more efficient implementation conditionally based on the number of values.
if rhs.is_none() {
if rhs.is_empty() {
// If rhs is None, IN expressions are always false and NOT IN expressions are always true.
if not {
// On a trivially true NOT IN () expression we can only jump to the 'jump_target_when_true' label if 'jump_if_condition_is_true'; otherwise me must fall through.
@@ -195,8 +198,6 @@ fn translate_in_list(
let lhs_reg = program.alloc_register();
let _ = translate_expr(program, referenced_tables, lhs, lhs_reg, resolver)?;
let rhs = rhs.as_ref().unwrap();
// The difference between a local jump and an "upper level" jump is that for example in this case:
// WHERE foo IN (1,2,3) OR bar = 5,
// we can immediately jump to the 'jump_target_when_true' label of the ENTIRE CONDITION if foo = 1, foo = 2, or foo = 3 without evaluating the bar = 5 condition.
@@ -689,7 +690,7 @@ pub fn translate_expr(
// First translate inner expr, then set the curr collation. If we set curr collation before,
// it may be overwritten later by inner translate.
translate_expr(program, referenced_tables, expr, target_register, resolver)?;
let collation = CollationSeq::new(collation)?;
let collation = CollationSeq::new(collation.as_str())?;
program.set_collation(Some((collation, true)));
Ok(target_register)
}
@@ -702,7 +703,7 @@ pub fn translate_expr(
filter_over: _,
order_by: _,
} => {
let args_count = if let Some(args) = args { args.len() } else { 0 };
let args_count = args.len();
let func_type = resolver.resolve_function(name.as_str(), args_count);
if func_type.is_none() {
@@ -720,16 +721,8 @@ pub fn translate_expr(
}
Func::External(_) => {
let regs = program.alloc_registers(args_count);
if let Some(args) = args {
for (i, arg_expr) in args.iter().enumerate() {
translate_expr(
program,
referenced_tables,
arg_expr,
regs + i,
resolver,
)?;
}
for (i, arg_expr) in args.iter().enumerate() {
translate_expr(program, referenced_tables, arg_expr, regs + i, resolver)?;
}
// Use shared function call helper
@@ -764,7 +757,7 @@ pub fn translate_expr(
| JsonFunc::JsonInsert
| JsonFunc::JsonbInsert => translate_function(
program,
args.as_deref().unwrap_or_default(),
args,
referenced_tables,
resolver,
target_register,
@@ -788,20 +781,12 @@ pub fn translate_expr(
)
}
JsonFunc::JsonErrorPosition => {
let args = if let Some(args) = args {
if args.len() != 1 {
crate::bail_parse_error!(
"{} function with not exactly 1 argument",
j.to_string()
);
}
args
} else {
if args.len() != 1 {
crate::bail_parse_error!(
"{} function with no arguments",
"{} function with not exactly 1 argument",
j.to_string()
);
};
}
let json_reg = program.alloc_register();
translate_expr(program, referenced_tables, &args[0], json_reg, resolver)?;
program.emit_insn(Insn::Function {
@@ -826,7 +811,7 @@ pub fn translate_expr(
}
JsonFunc::JsonValid => translate_function(
program,
args.as_deref().unwrap_or_default(),
args,
referenced_tables,
resolver,
target_register,
@@ -844,19 +829,16 @@ pub fn translate_expr(
)
}
JsonFunc::JsonRemove => {
let start_reg =
program.alloc_registers(args.as_ref().map(|x| x.len()).unwrap_or(1));
if let Some(args) = args {
for (i, arg) in args.iter().enumerate() {
// register containing result of each argument expression
translate_expr(
program,
referenced_tables,
arg,
start_reg + i,
resolver,
)?;
}
let start_reg = program.alloc_registers(args.len().max(1));
for (i, arg) in args.iter().enumerate() {
// register containing result of each argument expression
translate_expr(
program,
referenced_tables,
arg,
start_reg + i,
resolver,
)?;
}
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -959,7 +941,7 @@ pub fn translate_expr(
unreachable!("this is always ast::Expr::Cast")
}
ScalarFunc::Changes => {
if args.is_some() {
if !args.is_empty() {
crate::bail_parse_error!(
"{} function with more than 0 arguments",
srf
@@ -976,7 +958,7 @@ pub fn translate_expr(
}
ScalarFunc::Char => translate_function(
program,
args.as_deref().unwrap_or_default(),
args,
referenced_tables,
resolver,
target_register,
@@ -1019,9 +1001,7 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Concat => {
let args = if let Some(args) = args {
args
} else {
if args.is_empty() {
crate::bail_parse_error!(
"{} function with no arguments",
srf.to_string()
@@ -1069,17 +1049,12 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::IfNull => {
let args = match args {
Some(args) if args.len() == 2 => args,
Some(_) => crate::bail_parse_error!(
if args.len() != 2 {
crate::bail_parse_error!(
"{} function requires exactly 2 arguments",
srf.to_string()
),
None => crate::bail_parse_error!(
"{} function requires arguments",
srf.to_string()
),
};
);
}
let temp_reg = program.alloc_register();
translate_expr_no_constant_opt(
@@ -1114,13 +1089,12 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Iif => {
let args = match args {
Some(args) if args.len() == 3 => args,
_ => crate::bail_parse_error!(
if args.len() != 3 {
crate::bail_parse_error!(
"{} requires exactly 3 arguments",
srf.to_string()
),
};
);
}
let temp_reg = program.alloc_register();
translate_expr_no_constant_opt(
program,
@@ -1161,20 +1135,12 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Glob | ScalarFunc::Like => {
let args = if let Some(args) = args {
if args.len() < 2 {
crate::bail_parse_error!(
"{} function with less than 2 arguments",
srf.to_string()
);
}
args
} else {
if args.len() < 2 {
crate::bail_parse_error!(
"{} function with no arguments",
"{} function with less than 2 arguments",
srf.to_string()
);
};
}
let func_registers = program.alloc_registers(args.len());
for (i, arg) in args.iter().enumerate() {
let _ = translate_expr(
@@ -1245,7 +1211,7 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Random => {
if args.is_some() {
if !args.is_empty() {
crate::bail_parse_error!(
"{} function with arguments",
srf.to_string()
@@ -1261,19 +1227,16 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Date | ScalarFunc::DateTime | ScalarFunc::JulianDay => {
let start_reg = program
.alloc_registers(args.as_ref().map(|x| x.len()).unwrap_or(1));
if let Some(args) = args {
for (i, arg) in args.iter().enumerate() {
// register containing result of each argument expression
translate_expr(
program,
referenced_tables,
arg,
start_reg + i,
resolver,
)?;
}
let start_reg = program.alloc_registers(args.len().max(1));
for (i, arg) in args.iter().enumerate() {
// register containing result of each argument expression
translate_expr(
program,
referenced_tables,
arg,
start_reg + i,
resolver,
)?;
}
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -1284,20 +1247,12 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Substr | ScalarFunc::Substring => {
let args = if let Some(args) = args {
if !(args.len() == 2 || args.len() == 3) {
crate::bail_parse_error!(
"{} function with wrong number of arguments",
srf.to_string()
)
}
args
} else {
if !(args.len() == 2 || args.len() == 3) {
crate::bail_parse_error!(
"{} function with no arguments",
"{} function with wrong number of arguments",
srf.to_string()
);
};
)
}
let str_reg = program.alloc_register();
let start_reg = program.alloc_register();
@@ -1334,16 +1289,11 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Hex => {
let args = if let Some(args) = args {
if args.len() != 1 {
crate::bail_parse_error!(
"hex function must have exactly 1 argument",
);
}
args
} else {
crate::bail_parse_error!("hex function with no arguments",);
};
if args.len() != 1 {
crate::bail_parse_error!(
"hex function must have exactly 1 argument",
);
}
let start_reg = program.alloc_register();
translate_expr(
program,
@@ -1362,22 +1312,19 @@ pub fn translate_expr(
}
ScalarFunc::UnixEpoch => {
let mut start_reg = 0;
match args {
Some(args) if args.len() > 1 => {
crate::bail_parse_error!("epoch function with > 1 arguments. Modifiers are not yet supported.");
}
Some(args) if args.len() == 1 => {
let arg_reg = program.alloc_register();
let _ = translate_expr(
program,
referenced_tables,
&args[0],
arg_reg,
resolver,
)?;
start_reg = arg_reg;
}
_ => {}
if args.len() > 1 {
crate::bail_parse_error!("epoch function with > 1 arguments. Modifiers are not yet supported.");
}
if args.len() == 1 {
let arg_reg = program.alloc_register();
let _ = translate_expr(
program,
referenced_tables,
&args[0],
arg_reg,
resolver,
)?;
start_reg = arg_reg;
}
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -1388,19 +1335,16 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Time => {
let start_reg = program
.alloc_registers(args.as_ref().map(|x| x.len()).unwrap_or(1));
if let Some(args) = args {
for (i, arg) in args.iter().enumerate() {
// register containing result of each argument expression
translate_expr(
program,
referenced_tables,
arg,
start_reg + i,
resolver,
)?;
}
let start_reg = program.alloc_registers(args.len().max(1));
for (i, arg) in args.iter().enumerate() {
// register containing result of each argument expression
translate_expr(
program,
referenced_tables,
arg,
start_reg + i,
resolver,
)?;
}
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -1438,7 +1382,7 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::TotalChanges => {
if args.is_some() {
if !args.is_empty() {
crate::bail_parse_error!(
"{} function with more than 0 arguments",
srf.to_string()
@@ -1479,16 +1423,9 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Min => {
let args = if let Some(args) = args {
if args.is_empty() {
crate::bail_parse_error!(
"min function with less than one argument"
);
}
args
} else {
if args.is_empty() {
crate::bail_parse_error!("min function with no arguments");
};
}
let start_reg = program.alloc_registers(args.len());
for (i, arg) in args.iter().enumerate() {
translate_expr(
@@ -1509,16 +1446,9 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Max => {
let args = if let Some(args) = args {
if args.is_empty() {
crate::bail_parse_error!(
"max function with less than one argument"
);
}
args
} else {
crate::bail_parse_error!("max function with no arguments");
};
if args.is_empty() {
crate::bail_parse_error!("min function with no arguments");
}
let start_reg = program.alloc_registers(args.len());
for (i, arg) in args.iter().enumerate() {
translate_expr(
@@ -1539,20 +1469,12 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Nullif | ScalarFunc::Instr => {
let args = if let Some(args) = args {
if args.len() != 2 {
crate::bail_parse_error!(
"{} function must have two argument",
srf.to_string()
);
}
args
} else {
if args.len() != 2 {
crate::bail_parse_error!(
"{} function with no arguments",
"{} function must have two argument",
srf.to_string()
);
};
}
let first_reg = program.alloc_register();
translate_expr(
@@ -1580,7 +1502,7 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::SqliteVersion => {
if args.is_some() {
if !args.is_empty() {
crate::bail_parse_error!("sqlite_version function with arguments");
}
@@ -1600,7 +1522,7 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::SqliteSourceId => {
if args.is_some() {
if !args.is_empty() {
crate::bail_parse_error!(
"sqlite_source_id function with arguments"
);
@@ -1622,20 +1544,13 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Replace => {
let args = if let Some(args) = args {
if !args.len() == 3 {
crate::bail_parse_error!(
"function {}() requires exactly 3 arguments",
srf.to_string()
)
}
args
} else {
if !args.len() == 3 {
crate::bail_parse_error!(
"function {}() requires exactly 3 arguments",
srf.to_string()
);
};
)
}
let str_reg = program.alloc_register();
let pattern_reg = program.alloc_register();
let replacement_reg = program.alloc_register();
@@ -1669,19 +1584,16 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::StrfTime => {
let start_reg = program
.alloc_registers(args.as_ref().map(|x| x.len()).unwrap_or(1));
if let Some(args) = args {
for (i, arg) in args.iter().enumerate() {
// register containing result of each argument expression
translate_expr(
program,
referenced_tables,
arg,
start_reg + i,
resolver,
)?;
}
let start_reg = program.alloc_registers(args.len().max(1));
for (i, arg) in args.iter().enumerate() {
// register containing result of each argument expression
translate_expr(
program,
referenced_tables,
arg,
start_reg + i,
resolver,
)?;
}
program.emit_insn(Insn::Function {
constant_mask: 0,
@@ -1693,23 +1605,18 @@ pub fn translate_expr(
}
ScalarFunc::Printf => translate_function(
program,
args.as_deref().unwrap_or(&[]),
args,
referenced_tables,
resolver,
target_register,
func_ctx,
),
ScalarFunc::Likely => {
let args = if let Some(args) = args {
if args.len() != 1 {
crate::bail_parse_error!(
"likely function must have exactly 1 argument",
);
}
args
} else {
crate::bail_parse_error!("likely function with no arguments",);
};
if args.len() != 1 {
crate::bail_parse_error!(
"likely function must have exactly 1 argument",
);
}
translate_expr(
program,
referenced_tables,
@@ -1720,18 +1627,15 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::Likelihood => {
let args = if let Some(args) = args {
if args.len() != 2 {
crate::bail_parse_error!(
"likelihood() function must have exactly 2 arguments",
);
}
args
} else {
crate::bail_parse_error!("likelihood() function with no arguments",);
};
if args.len() != 2 {
crate::bail_parse_error!(
"likelihood() function must have exactly 2 arguments",
);
}
if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = args[1] {
if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) =
args[1].as_ref()
{
if let Ok(probability) = value.parse::<f64>() {
if !(0.0..=1.0).contains(&probability) {
crate::bail_parse_error!(
@@ -1763,12 +1667,11 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::TableColumnsJsonArray => {
if args.is_none() || args.as_ref().unwrap().len() != 1 {
if args.len() != 1 {
crate::bail_parse_error!(
"table_columns_json_array() function must have exactly 1 argument",
);
}
let args = args.as_ref().unwrap();
let start_reg = program.alloc_register();
translate_expr(
program,
@@ -1786,12 +1689,11 @@ pub fn translate_expr(
Ok(target_register)
}
ScalarFunc::BinRecordJsonObject => {
if args.is_none() || args.as_ref().unwrap().len() != 2 {
if args.len() != 2 {
crate::bail_parse_error!(
"bin_record_json_object() function must have exactly 2 arguments",
);
}
let args = args.as_ref().unwrap();
let start_reg = program.alloc_registers(2);
translate_expr(
program,
@@ -1828,16 +1730,11 @@ pub fn translate_expr(
);
}
ScalarFunc::Unlikely => {
let args = if let Some(args) = args {
if args.len() != 1 {
crate::bail_parse_error!(
"Unlikely function must have exactly 1 argument",
);
}
args
} else {
crate::bail_parse_error!("Unlikely function with no arguments",);
};
if args.len() != 1 {
crate::bail_parse_error!(
"Unlikely function must have exactly 1 argument",
);
}
translate_expr(
program,
referenced_tables,
@@ -1852,7 +1749,7 @@ pub fn translate_expr(
}
Func::Math(math_func) => match math_func.arity() {
MathFuncArity::Nullary => {
if args.is_some() {
if !args.is_empty() {
crate::bail_parse_error!("{} function with arguments", math_func);
}
@@ -1924,7 +1821,7 @@ pub fn translate_expr(
Func::AlterTable(_) => unreachable!(),
}
}
ast::Expr::FunctionCallStar { .. } => todo!(),
ast::Expr::FunctionCallStar { .. } => todo!("{:?}", &expr),
ast::Expr::Id(id) => {
// Treat double-quoted identifiers as string literals (SQLite compatibility)
program.emit_insn(Insn::String8 {
@@ -2639,7 +2536,7 @@ fn translate_like_base(
/// Returns the target register for the function.
fn translate_function(
program: &mut ProgramBuilder,
args: &[ast::Expr],
args: &[Box<ast::Expr>],
referenced_tables: Option<&TableReferences>,
resolver: &Resolver,
target_register: usize,
@@ -2765,7 +2662,7 @@ pub fn unwrap_parens_owned(expr: ast::Expr) -> Result<(ast::Expr, usize)> {
ast::Expr::Parenthesized(mut exprs) => match exprs.len() {
1 => {
paren_count += 1;
let (expr, count) = unwrap_parens_owned(exprs.pop().unwrap())?;
let (expr, count) = unwrap_parens_owned(*exprs.pop().unwrap().clone())?;
paren_count += count;
Ok((expr, paren_count))
}
@@ -2830,81 +2727,63 @@ where
filter_over,
..
} => {
if let Some(args) = args {
for arg in args {
walk_expr(arg, func)?;
}
for arg in args {
walk_expr(arg, func)?;
}
if let Some(order_by) = order_by {
for sort_col in order_by {
walk_expr(&sort_col.expr, func)?;
}
for sort_col in order_by {
walk_expr(&sort_col.expr, func)?;
}
if let Some(filter_over) = filter_over {
if let Some(filter_clause) = &filter_over.filter_clause {
walk_expr(filter_clause, func)?;
}
if let Some(over_clause) = &filter_over.over_clause {
match over_clause.as_ref() {
ast::Over::Window(window) => {
if let Some(partition_by) = &window.partition_by {
for part_expr in partition_by {
walk_expr(part_expr, func)?;
}
}
if let Some(order_by_clause) = &window.order_by {
for sort_col in order_by_clause {
walk_expr(&sort_col.expr, func)?;
}
}
if let Some(frame_clause) = &window.frame_clause {
walk_expr_frame_bound(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_expr_frame_bound(end_bound, func)?;
}
if let Some(filter_clause) = &filter_over.filter_clause {
walk_expr(filter_clause, func)?;
}
if let Some(over_clause) = &filter_over.over_clause {
match over_clause {
ast::Over::Window(window) => {
for part_expr in &window.partition_by {
walk_expr(part_expr, func)?;
}
for sort_col in &window.order_by {
walk_expr(&sort_col.expr, func)?;
}
if let Some(frame_clause) = &window.frame_clause {
walk_expr_frame_bound(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_expr_frame_bound(end_bound, func)?;
}
}
ast::Over::Name(_) => {}
}
ast::Over::Name(_) => {}
}
}
}
ast::Expr::FunctionCallStar { filter_over, .. } => {
if let Some(filter_over) = filter_over {
if let Some(filter_clause) = &filter_over.filter_clause {
walk_expr(filter_clause, func)?;
}
if let Some(over_clause) = &filter_over.over_clause {
match over_clause.as_ref() {
ast::Over::Window(window) => {
if let Some(partition_by) = &window.partition_by {
for part_expr in partition_by {
walk_expr(part_expr, func)?;
}
}
if let Some(order_by_clause) = &window.order_by {
for sort_col in order_by_clause {
walk_expr(&sort_col.expr, func)?;
}
}
if let Some(frame_clause) = &window.frame_clause {
walk_expr_frame_bound(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_expr_frame_bound(end_bound, func)?;
}
if let Some(filter_clause) = &filter_over.filter_clause {
walk_expr(filter_clause, func)?;
}
if let Some(over_clause) = &filter_over.over_clause {
match over_clause {
ast::Over::Window(window) => {
for part_expr in &window.partition_by {
walk_expr(part_expr, func)?;
}
for sort_col in &window.order_by {
walk_expr(&sort_col.expr, func)?;
}
if let Some(frame_clause) = &window.frame_clause {
walk_expr_frame_bound(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_expr_frame_bound(end_bound, func)?;
}
}
ast::Over::Name(_) => {}
}
ast::Over::Name(_) => {}
}
}
}
ast::Expr::InList { lhs, rhs, .. } => {
walk_expr(lhs, func)?;
if let Some(rhs_exprs) = rhs {
for expr in rhs_exprs {
walk_expr(expr, func)?;
}
for expr in rhs {
walk_expr(expr, func)?;
}
}
ast::Expr::InSelect { lhs, rhs: _, .. } => {
@@ -2913,10 +2792,8 @@ where
}
ast::Expr::InTable { lhs, args, .. } => {
walk_expr(lhs, func)?;
if let Some(arg_exprs) = args {
for expr in arg_exprs {
walk_expr(expr, func)?;
}
for expr in args {
walk_expr(expr, func)?;
}
}
ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => {
@@ -3026,81 +2903,63 @@ where
filter_over,
..
} => {
if let Some(args) = args {
for arg in args {
walk_expr_mut(arg, func)?;
}
for arg in args {
walk_expr_mut(arg, func)?;
}
if let Some(order_by) = order_by {
for sort_col in order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
for sort_col in order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(filter_over) = filter_over {
if let Some(filter_clause) = &mut filter_over.filter_clause {
walk_expr_mut(filter_clause, func)?;
}
if let Some(over_clause) = &mut filter_over.over_clause {
match over_clause.as_mut() {
ast::Over::Window(window) => {
if let Some(partition_by) = &mut window.partition_by {
for part_expr in partition_by {
walk_expr_mut(part_expr, func)?;
}
}
if let Some(order_by_clause) = &mut window.order_by {
for sort_col in order_by_clause {
walk_expr_mut(&mut sort_col.expr, func)?;
}
}
if let Some(frame_clause) = &mut window.frame_clause {
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
if let Some(end_bound) = &mut frame_clause.end {
walk_expr_mut_frame_bound(end_bound, func)?;
}
if let Some(filter_clause) = &mut filter_over.filter_clause {
walk_expr_mut(filter_clause, func)?;
}
if let Some(over_clause) = &mut filter_over.over_clause {
match over_clause {
ast::Over::Window(window) => {
for part_expr in &mut window.partition_by {
walk_expr_mut(part_expr, func)?;
}
for sort_col in &mut window.order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(frame_clause) = &mut window.frame_clause {
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
if let Some(end_bound) = &mut frame_clause.end {
walk_expr_mut_frame_bound(end_bound, func)?;
}
}
ast::Over::Name(_) => {}
}
ast::Over::Name(_) => {}
}
}
}
ast::Expr::FunctionCallStar { filter_over, .. } => {
if let Some(filter_over) = filter_over {
if let Some(filter_clause) = &mut filter_over.filter_clause {
walk_expr_mut(filter_clause, func)?;
}
if let Some(over_clause) = &mut filter_over.over_clause {
match over_clause.as_mut() {
ast::Over::Window(window) => {
if let Some(partition_by) = &mut window.partition_by {
for part_expr in partition_by {
walk_expr_mut(part_expr, func)?;
}
}
if let Some(order_by_clause) = &mut window.order_by {
for sort_col in order_by_clause {
walk_expr_mut(&mut sort_col.expr, func)?;
}
}
if let Some(frame_clause) = &mut window.frame_clause {
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
if let Some(end_bound) = &mut frame_clause.end {
walk_expr_mut_frame_bound(end_bound, func)?;
}
if let Some(ref mut filter_clause) = filter_over.filter_clause {
walk_expr_mut(filter_clause, func)?;
}
if let Some(ref mut over_clause) = filter_over.over_clause {
match over_clause {
ast::Over::Window(window) => {
for part_expr in &mut window.partition_by {
walk_expr_mut(part_expr, func)?;
}
for sort_col in &mut window.order_by {
walk_expr_mut(&mut sort_col.expr, func)?;
}
if let Some(frame_clause) = &mut window.frame_clause {
walk_expr_mut_frame_bound(&mut frame_clause.start, func)?;
if let Some(end_bound) = &mut frame_clause.end {
walk_expr_mut_frame_bound(end_bound, func)?;
}
}
ast::Over::Name(_) => {}
}
ast::Over::Name(_) => {}
}
}
}
ast::Expr::InList { lhs, rhs, .. } => {
walk_expr_mut(lhs, func)?;
if let Some(rhs_exprs) = rhs {
for expr in rhs_exprs {
walk_expr_mut(expr, func)?;
}
for expr in rhs {
walk_expr_mut(expr, func)?;
}
}
ast::Expr::InSelect { lhs, rhs: _, .. } => {
@@ -3109,10 +2968,8 @@ where
}
ast::Expr::InTable { lhs, args, .. } => {
walk_expr_mut(lhs, func)?;
if let Some(arg_exprs) = args {
for expr in arg_exprs {
walk_expr_mut(expr, func)?;
}
for expr in args {
walk_expr_mut(expr, func)?;
}
}
ast::Expr::IsNull(expr) | ast::Expr::NotNull(expr) => {
@@ -3317,12 +3174,10 @@ pub fn translate_expr_for_returning(
Expr::FunctionCall { name, args, .. } => {
// Evaluate arguments into registers
let mut arg_regs = Vec::new();
if let Some(args) = args {
for arg in args.iter() {
let arg_reg = program.alloc_register();
translate_expr_for_returning(program, arg, value_registers, arg_reg)?;
arg_regs.push(arg_reg);
}
for arg in args.iter() {
let arg_reg = program.alloc_register();
translate_expr_for_returning(program, arg, value_registers, arg_reg)?;
arg_regs.push(arg_reg);
}
// Resolve and call the function using shared helper
@@ -3492,7 +3347,7 @@ pub fn process_returning_clause(
bind_column_references(expr, &mut table_references, None, connection)?;
result_columns.push(ResultSetColumn {
expr: expr.clone(),
expr: *expr.clone(),
alias: column_alias,
contains_aggregates: false,
});

View File

@@ -85,7 +85,7 @@ pub fn init_group_by<'a>(
group_by: &'a GroupBy,
plan: &SelectPlan,
result_columns: &'a [ResultSetColumn],
order_by: &'a Option<Vec<(ast::Expr, ast::SortOrder)>>,
order_by: &'a [(Box<ast::Expr>, ast::SortOrder)],
) -> Result<()> {
collect_non_aggregate_expressions(
&mut t_ctx.non_aggregate_expressions,
@@ -141,7 +141,7 @@ pub fn init_group_by<'a>(
.iter()
.map(|expr| match expr {
ast::Expr::Collate(_, collation_name) => {
CollationSeq::new(collation_name).map(Some)
CollationSeq::new(collation_name.as_str()).map(Some)
}
ast::Expr::Column { table, column, .. } => {
let table_reference = plan
@@ -238,13 +238,13 @@ fn collect_non_aggregate_expressions<'a>(
group_by: &'a GroupBy,
plan: &SelectPlan,
root_result_columns: &'a [ResultSetColumn],
order_by: &'a Option<Vec<(ast::Expr, ast::SortOrder)>>,
order_by: &'a [(Box<ast::Expr>, ast::SortOrder)],
) -> Result<()> {
let mut result_columns = Vec::new();
for expr in root_result_columns
.iter()
.map(|col| &col.expr)
.chain(order_by.iter().flat_map(|o| o.iter().map(|(e, _)| e)))
.chain(order_by.iter().map(|(e, _)| e.as_ref()))
.chain(group_by.having.iter().flatten())
{
collect_result_columns(expr, plan, &mut result_columns)?;
@@ -821,8 +821,8 @@ pub fn group_by_emit_row_phase<'a>(
}
}
match &plan.order_by {
None => {
match plan.order_by.is_empty() {
true => {
emit_select_result(
program,
&t_ctx.resolver,
@@ -835,7 +835,7 @@ pub fn group_by_emit_row_phase<'a>(
t_ctx.limit_ctx,
)?;
}
Some(_) => {
false => {
order_by_sorter_insert(
program,
&t_ctx.resolver,
@@ -954,8 +954,8 @@ pub fn translate_aggregation_step_groupby(
if num_args == 2 {
match &agg_arg_source.args()[1] {
ast::Expr::Column { .. } => {
delimiter_expr = agg_arg_source.args()[1].clone();
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
ast::Expr::Literal(ast::Literal::String(s)) => {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string()));
@@ -1054,7 +1054,7 @@ pub fn translate_aggregation_step_groupby(
let delimiter_reg = program.alloc_register();
let delimiter_expr = match &agg_arg_source.args()[1] {
ast::Expr::Column { .. } => agg_arg_source.args()[1].clone(),
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
}

View File

@@ -255,7 +255,7 @@ fn resolve_sorted_columns<'a>(
) -> crate::Result<Vec<((usize, &'a Column), SortOrder)>> {
let mut resolved = Vec::with_capacity(cols.len());
for sc in cols {
let ident = normalize_ident(match &sc.expr {
let ident = normalize_ident(match sc.expr.as_ref() {
// SQLite supports indexes on arbitrary expressions, but we don't (yet).
// See "How to use indexes on expressions" in https://www.sqlite.org/expridx.html
Expr::Id(ast::Name::Ident(col_name))

View File

@@ -1,8 +1,7 @@
use std::sync::Arc;
use turso_parser::ast::{
self, Expr, InsertBody, OneSelect, QualifiedName, ResolveType, ResultColumn,
With,
self, Expr, InsertBody, OneSelect, QualifiedName, ResolveType, ResultColumn, With,
};
use crate::error::{SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY};
@@ -13,7 +12,6 @@ use crate::translate::emitter::{
use crate::translate::expr::{
emit_returning_results, process_returning_clause, ReturningValueRegisters,
};
use crate::translate::plan::TableReferences;
use crate::translate::planner::ROWID;
use crate::util::normalize_ident;
use crate::vdbe::builder::ProgramBuilderOpts;
@@ -46,9 +44,9 @@ pub fn translate_insert(
with: Option<With>,
on_conflict: Option<ResolveType>,
tbl_name: QualifiedName,
columns: Option<DistinctNames>,
columns: Vec<ast::Name>,
mut body: InsertBody,
mut returning: Option<Vec<ResultColumn>>,
mut returning: Vec<ResultColumn>,
syms: &SymbolTable,
mut program: ProgramBuilder,
connection: &Arc<crate::Connection>,
@@ -102,9 +100,9 @@ pub fn translate_insert(
let root_page = btree_table.root_page;
let mut values: Option<Vec<Expr>> = None;
let mut values: Option<Vec<Box<Expr>>> = None;
let inserting_multiple_rows = match &mut body {
InsertBody::Select(select, _) => match select.body.select.as_mut() {
InsertBody::Select(select, _) => match &mut select.body.select {
// TODO see how to avoid clone
OneSelect::Values(values_expr) if values_expr.len() <= 1 => {
if values_expr.is_empty() {
@@ -112,10 +110,11 @@ pub fn translate_insert(
}
let mut param_idx = 1;
for expr in values_expr.iter_mut().flat_map(|v| v.iter_mut()) {
match expr {
match expr.as_mut() {
Expr::Id(name) => {
if name.is_double_quoted() {
*expr = Expr::Literal(ast::Literal::String(format!("{name}")));
*expr =
Expr::Literal(ast::Literal::String(format!("{name}"))).into();
} else {
// an INSERT INTO ... VALUES (...) cannot reference columns
crate::bail_parse_error!("no such column: {name}");
@@ -143,17 +142,13 @@ pub fn translate_insert(
let cdc_table = prepare_cdc_if_necessary(&mut program, schema, table.get_name())?;
// Process RETURNING clause using shared module
let (result_columns, _) = if let Some(returning) = &mut returning {
process_returning_clause(
returning,
&table,
table_name.as_str(),
&mut program,
connection,
)?
} else {
(vec![], TableReferences::new(vec![], vec![]))
};
let (result_columns, _) = process_returning_clause(
&mut returning,
&table,
table_name.as_str(),
&mut program,
connection,
)?;
// Set up the program to return result columns if RETURNING is specified
if !result_columns.is_empty() {
@@ -166,8 +161,7 @@ pub fn translate_insert(
// TODO: upsert
InsertBody::Select(select, _) => {
// Simple Common case of INSERT INTO <table> VALUES (...)
if matches!(select.body.select.as_ref(), OneSelect::Values(values) if values.len() <= 1)
{
if matches!(&select.body.select, OneSelect::Values(values) if values.len() <= 1) {
(
values.as_ref().unwrap().len(),
program.alloc_cursor_id(CursorType::BTreeTable(btree_table.clone())),
@@ -190,14 +184,8 @@ pub fn translate_insert(
coroutine_implementation_start: halt_label,
};
program.incr_nesting();
let result = translate_select(
schema,
*select,
syms,
program,
query_destination,
connection,
)?;
let result =
translate_select(schema, select, syms, program, query_destination, connection)?;
program = result.program;
program.decr_nesting();
@@ -721,7 +709,7 @@ struct ColMapping<'a> {
fn build_insertion<'a>(
program: &mut ProgramBuilder,
table: &'a Table,
columns: &Option<DistinctNames>,
columns: &'a [ast::Name],
num_values: usize,
) -> Result<Insertion<'a>> {
let table_columns = table.columns();
@@ -739,7 +727,7 @@ fn build_insertion<'a>(
})
.collect::<Vec<_>>();
if columns.is_none() {
if columns.is_empty() {
// Case 1: No columns specified - map values to columns in order
if num_values != table_columns.iter().filter(|c| !c.hidden).count() {
crate::bail_parse_error!(
@@ -769,7 +757,7 @@ fn build_insertion<'a>(
} else {
// Case 2: Columns specified - map named columns to their values
// Map each named column to its value index
for (value_index, column_name) in columns.as_ref().unwrap().iter().enumerate() {
for (value_index, column_name) in columns.iter().enumerate() {
let column_name = normalize_ident(column_name.as_str());
if let Some((idx_in_table, col_in_table)) = table.get_column_by_name(&column_name) {
// Named column
@@ -850,7 +838,7 @@ fn translate_rows_multiple<'short, 'long: 'short>(
#[allow(clippy::too_many_arguments)]
fn translate_rows_single(
program: &mut ProgramBuilder,
value: &[Expr],
value: &[Box<Expr>],
insertion: &Insertion,
resolver: &Resolver,
) -> Result<()> {
@@ -976,7 +964,7 @@ fn translate_column(
fn translate_virtual_table_insert(
mut program: ProgramBuilder,
virtual_table: Arc<VirtualTable>,
columns: Option<DistinctNames>,
columns: Vec<ast::Name>,
mut body: InsertBody,
on_conflict: Option<ResolveType>,
resolver: &Resolver,
@@ -985,7 +973,7 @@ fn translate_virtual_table_insert(
crate::bail_constraint_error!("Table is read-only: {}", virtual_table.name);
}
let (num_values, value) = match &mut body {
InsertBody::Select(select, None) => match select.body.select.as_mut() {
InsertBody::Select(select, None) => match &mut select.body.select {
OneSelect::Values(values) => (values[0].len(), values.pop().unwrap()),
_ => crate::bail_parse_error!("Virtual tables only support VALUES clause in INSERT"),
},

View File

@@ -741,7 +741,7 @@ pub fn emit_loop(
return emit_loop_source(program, t_ctx, plan, LoopEmitTarget::AggStep);
}
// if we DONT have a group by, but we have an order by, we emit a record into the order by sorter.
if plan.order_by.is_some() {
if !plan.order_by.is_empty() {
return emit_loop_source(program, t_ctx, plan, LoopEmitTarget::OrderBySorter);
}
// if we have neither, we emit a ResultRow. In that case, if we have a Limit, we handle that with DecrJumpZero.

View File

@@ -70,9 +70,9 @@ pub fn translate(
let change_cnt_on = matches!(
stmt,
ast::Stmt::CreateIndex { .. }
| ast::Stmt::Delete(..)
| ast::Stmt::Insert(..)
| ast::Stmt::Update(..)
| ast::Stmt::Delete { .. }
| ast::Stmt::Insert { .. }
| ast::Stmt::Update { .. }
);
let mut program = ProgramBuilder::new(
@@ -90,11 +90,11 @@ pub fn translate(
program = match stmt {
// There can be no nesting with pragma, so lift it up here
ast::Stmt::Pragma(name, body) => pragma::translate_pragma(
ast::Stmt::Pragma { name, body } => pragma::translate_pragma(
schema,
syms,
&name,
body.map(|b| *b),
body,
pager,
connection.clone(),
program,
@@ -120,20 +120,20 @@ pub fn translate_inner(
) -> Result<ProgramBuilder> {
let is_write = matches!(
stmt,
ast::Stmt::AlterTable(..)
ast::Stmt::AlterTable { .. }
| ast::Stmt::CreateIndex { .. }
| ast::Stmt::CreateTable { .. }
| ast::Stmt::CreateTrigger { .. }
| ast::Stmt::CreateView { .. }
| ast::Stmt::CreateMaterializedView { .. }
| ast::Stmt::CreateVirtualTable(..)
| ast::Stmt::Delete(..)
| ast::Stmt::Delete { .. }
| ast::Stmt::DropIndex { .. }
| ast::Stmt::DropTable { .. }
| ast::Stmt::DropView { .. }
| ast::Stmt::Reindex { .. }
| ast::Stmt::Update(..)
| ast::Stmt::Insert(..)
| ast::Stmt::Update { .. }
| ast::Stmt::Insert { .. }
);
if is_write && connection.get_query_only() {
@@ -144,16 +144,14 @@ pub fn translate_inner(
let mut program = match stmt {
ast::Stmt::AlterTable(alter) => {
translate_alter_table(*alter, syms, schema, program, connection, input)?
translate_alter_table(alter, syms, schema, program, connection, input)?
}
ast::Stmt::Analyze(_) => bail_parse_error!("ANALYZE not supported yet"),
ast::Stmt::Analyze { .. } => bail_parse_error!("ANALYZE not supported yet"),
ast::Stmt::Attach { expr, db_name, key } => {
attach::translate_attach(&expr, &db_name, &key, schema, syms, program)?
}
ast::Stmt::Begin(tx_type, tx_name) => {
translate_tx_begin(tx_type, tx_name, schema, program)?
}
ast::Stmt::Commit(tx_name) => translate_tx_commit(tx_name, program)?,
ast::Stmt::Begin { typ, name } => translate_tx_begin(typ, name, schema, program)?,
ast::Stmt::Commit { name } => translate_tx_commit(name, program)?,
ast::Stmt::CreateIndex {
unique,
if_not_exists,
@@ -183,7 +181,7 @@ pub fn translate_inner(
} => translate_create_table(
tbl_name,
temporary,
*body,
body,
if_not_exists,
schema,
syms,
@@ -199,7 +197,7 @@ pub fn translate_inner(
schema,
view_name.name.as_str(),
&select,
columns.as_ref(),
&columns,
connection.clone(),
syms,
program,
@@ -215,25 +213,24 @@ pub fn translate_inner(
program,
)?,
ast::Stmt::CreateVirtualTable(vtab) => {
translate_create_virtual_table(*vtab, schema, syms, program)?
translate_create_virtual_table(vtab, schema, syms, program)?
}
ast::Stmt::Delete(delete) => {
let Delete {
tbl_name,
where_clause,
limit,
returning,
indexed,
order_by,
with,
} = *delete;
ast::Stmt::Delete {
tbl_name,
where_clause,
limit,
returning,
indexed,
order_by,
with,
} => {
if with.is_some() {
bail_parse_error!("WITH clause is not supported in DELETE");
}
if indexed.is_some_and(|i| matches!(i, Indexed::IndexedBy(_))) {
bail_parse_error!("INDEXED BY clause is not supported in DELETE");
}
if order_by.is_some() {
if !order_by.is_empty() {
bail_parse_error!("ORDER BY clause is not supported in DELETE");
}
translate_delete(
@@ -247,7 +244,7 @@ pub fn translate_inner(
connection,
)?
}
ast::Stmt::Detach(expr) => attach::translate_detach(&expr, schema, syms, program)?,
ast::Stmt::Detach { name } => attach::translate_detach(&name, schema, syms, program)?,
ast::Stmt::DropIndex {
if_exists,
idx_name,
@@ -261,20 +258,20 @@ pub fn translate_inner(
if_exists,
view_name,
} => view::translate_drop_view(schema, view_name.name.as_str(), if_exists, program)?,
ast::Stmt::Pragma(..) => {
ast::Stmt::Pragma { .. } => {
bail_parse_error!("PRAGMA statement cannot be evaluated in a nested context")
}
ast::Stmt::Reindex { .. } => bail_parse_error!("REINDEX not supported yet"),
ast::Stmt::Release(_) => bail_parse_error!("RELEASE not supported yet"),
ast::Stmt::Release { .. } => bail_parse_error!("RELEASE not supported yet"),
ast::Stmt::Rollback {
tx_name,
savepoint_name,
} => translate_rollback(schema, syms, program, tx_name, savepoint_name)?,
ast::Stmt::Savepoint(_) => bail_parse_error!("SAVEPOINT not supported yet"),
ast::Stmt::Savepoint { .. } => bail_parse_error!("SAVEPOINT not supported yet"),
ast::Stmt::Select(select) => {
translate_select(
schema,
*select,
select,
syms,
program,
plan::QueryDestination::ResultRows,
@@ -285,29 +282,26 @@ pub fn translate_inner(
ast::Stmt::Update(mut update) => {
translate_update(schema, &mut update, syms, program, connection)?
}
ast::Stmt::Vacuum(_, _) => bail_parse_error!("VACUUM not supported yet"),
ast::Stmt::Insert(insert) => {
let Insert {
with,
or_conflict,
tbl_name,
columns,
body,
returning,
} = *insert;
translate_insert(
schema,
with,
or_conflict,
tbl_name,
columns,
body,
returning,
syms,
program,
connection,
)?
}
ast::Stmt::Vacuum { .. } => bail_parse_error!("VACUUM not supported yet"),
ast::Stmt::Insert {
with,
or_conflict,
tbl_name,
columns,
body,
returning,
} => translate_insert(
schema,
with,
or_conflict,
tbl_name,
columns,
body,
returning,
syms,
program,
connection,
)?,
};
// Indicate write operations so that in the epilogue we can emit the correct type of transaction

View File

@@ -719,7 +719,7 @@ mod tests {
t2.clone(),
Some(JoinInfo {
outer: false,
using: None,
using: vec![],
}),
table_id_counter.next(),
),
@@ -823,7 +823,7 @@ mod tests {
table_customers.clone(),
Some(JoinInfo {
outer: false,
using: None,
using: vec![],
}),
table_id_counter.next(),
),
@@ -831,7 +831,7 @@ mod tests {
table_order_items.clone(),
Some(JoinInfo {
outer: false,
using: None,
using: vec![],
}),
table_id_counter.next(),
),
@@ -1007,7 +1007,7 @@ mod tests {
t2.clone(),
Some(JoinInfo {
outer: false,
using: None,
using: vec![],
}),
table_id_counter.next(),
),
@@ -1015,7 +1015,7 @@ mod tests {
t3.clone(),
Some(JoinInfo {
outer: false,
using: None,
using: vec![],
}),
table_id_counter.next(),
),
@@ -1113,7 +1113,7 @@ mod tests {
t.clone(),
Some(JoinInfo {
outer: false,
using: None,
using: vec![],
}),
table_id_counter.next(),
)
@@ -1122,7 +1122,7 @@ mod tests {
fact_table.clone(),
Some(JoinInfo {
outer: false,
using: None,
using: vec![],
}),
table_id_counter.next(),
));

View File

@@ -104,7 +104,7 @@ pub(crate) fn lift_common_subexpressions_from_binary_or_terms(
// If we unwrapped parentheses before, let's add them back.
let mut top_level_expr = rebuild_and_expr_from_list(conjunct_list_for_or_branch);
while num_unwrapped_parens > 0 {
top_level_expr = Expr::Parenthesized(vec![top_level_expr]);
top_level_expr = Expr::Parenthesized(vec![top_level_expr.into()]);
num_unwrapped_parens -= 1;
}
new_or_operands_for_original_term.push(top_level_expr);
@@ -246,11 +246,13 @@ mod tests {
let or_expr = Expr::Binary(
Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list(
vec![a_expr.clone(), x_expr.clone(), b_expr.clone()],
)])),
)
.into()])),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list(
vec![a_expr.clone(), y_expr.clone(), b_expr.clone()],
)])),
)
.into()])),
);
let mut where_clause = vec![WhereTerm {
@@ -273,9 +275,9 @@ mod tests {
assert_eq!(
nonconsumed_terms[0].expr,
Expr::Binary(
Box::new(ast::Expr::Parenthesized(vec![x_expr.clone()])),
Box::new(ast::Expr::Parenthesized(vec![x_expr.clone().into()])),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![y_expr.clone()]))
Box::new(ast::Expr::Parenthesized(vec![y_expr.clone().into()]))
)
);
assert_eq!(nonconsumed_terms[1].expr, a_expr);
@@ -340,16 +342,19 @@ mod tests {
Box::new(Expr::Binary(
Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list(
vec![a_expr.clone(), x_expr.clone()],
)])),
)
.into()])),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list(
vec![a_expr.clone(), y_expr.clone()],
)])),
)
.into()])),
)),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list(
vec![a_expr.clone(), z_expr.clone()],
)])),
)
.into()])),
);
let mut where_clause = vec![WhereTerm {
@@ -372,12 +377,12 @@ mod tests {
nonconsumed_terms[0].expr,
Expr::Binary(
Box::new(Expr::Binary(
Box::new(ast::Expr::Parenthesized(vec![x_expr])),
Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![y_expr])),
Box::new(ast::Expr::Parenthesized(vec![y_expr.into()])),
)),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![z_expr])),
Box::new(ast::Expr::Parenthesized(vec![z_expr.into()])),
)
);
assert_eq!(nonconsumed_terms[1].expr, a_expr);
@@ -414,9 +419,9 @@ mod tests {
);
let or_expr = Expr::Binary(
Box::new(ast::Expr::Parenthesized(vec![x_expr.clone()])),
Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![y_expr.clone()])),
Box::new(ast::Expr::Parenthesized(vec![y_expr.into()])),
);
let mut where_clause = vec![WhereTerm {
@@ -479,11 +484,13 @@ mod tests {
let or_expr = Expr::Binary(
Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list(
vec![a_expr.clone(), x_expr.clone()],
)])),
)
.into()])),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list(
vec![a_expr.clone(), y_expr.clone()],
)])),
)
.into()])),
);
let mut where_clause = vec![WhereTerm {
@@ -503,9 +510,9 @@ mod tests {
assert_eq!(
nonconsumed_terms[0].expr,
Expr::Binary(
Box::new(ast::Expr::Parenthesized(vec![x_expr])),
Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])),
Operator::Or,
Box::new(ast::Expr::Parenthesized(vec![y_expr]))
Box::new(ast::Expr::Parenthesized(vec![y_expr.into()]))
)
);
assert_eq!(

View File

@@ -186,7 +186,7 @@ fn optimize_table_access(
table_references: &mut TableReferences,
available_indexes: &HashMap<String, Vec<Arc<Index>>>,
where_clause: &mut [WhereTerm],
order_by: &mut Option<Vec<(ast::Expr, SortOrder)>>,
order_by: &mut Vec<(Box<ast::Expr>, SortOrder)>,
group_by: &mut Option<GroupBy>,
) -> Result<Option<Vec<JoinOrderMember>>> {
let access_methods_arena = RefCell::new(Vec::new());
@@ -241,11 +241,11 @@ fn optimize_table_access(
let _ = group_by.as_mut().and_then(|g| g.sort_order.take());
}
EliminatesSortBy::Order => {
let _ = order_by.take();
order_by.clear();
}
EliminatesSortBy::GroupByAndOrder => {
let _ = group_by.as_mut().and_then(|g| g.sort_order.take());
let _ = order_by.take();
order_by.clear();
}
}
}
@@ -467,7 +467,7 @@ fn build_vtab_scan_op(
.map(|(i, c)| {
c.ok_or_else(|| {
LimboError::ExtensionError(format!(
"argv_index values must form contiguous sequence starting from 1, missing index {}",
"argv_index values must form contiguous sequence starting from 1, missing index {}",
i + 1
))
})
@@ -536,10 +536,8 @@ fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> {
rewrite_expr(expr, &mut param_count)?;
}
}
if let Some(order_by) = &mut plan.order_by {
for (expr, _) in order_by.iter_mut() {
rewrite_expr(expr, &mut param_count)?;
}
for (expr, _) in plan.order_by.iter_mut() {
rewrite_expr(expr, &mut param_count)?;
}
Ok(())
@@ -561,10 +559,8 @@ fn rewrite_exprs_update(plan: &mut UpdatePlan) -> Result<()> {
for cond in plan.where_clause.iter_mut() {
rewrite_expr(&mut cond.expr, &mut param_idx)?;
}
if let Some(order_by) = &mut plan.order_by {
for (expr, _) in order_by.iter_mut() {
rewrite_expr(expr, &mut param_idx)?;
}
for (expr, _) in plan.order_by.iter_mut() {
rewrite_expr(expr, &mut param_idx)?;
}
if let Some(rc) = plan.returning.as_mut() {
for rc in rc.iter_mut() {
@@ -651,10 +647,7 @@ impl Optimizable for ast::Expr {
}
Expr::RowId { .. } => true,
Expr::InList { lhs, rhs, .. } => {
lhs.is_nonnull(tables)
&& rhs
.as_ref()
.is_none_or(|rhs| rhs.iter().all(|rhs| rhs.is_nonnull(tables)))
lhs.is_nonnull(tables) && rhs.is_empty() || rhs.iter().all(|v| v.is_nonnull(tables))
}
Expr::InSelect { .. } => false,
Expr::InTable { .. } => false,
@@ -715,15 +708,10 @@ impl Optimizable for ast::Expr {
}
Expr::Exists(_) => false,
Expr::FunctionCall { args, name, .. } => {
let Some(func) = resolver
.resolve_function(name.as_str(), args.as_ref().map_or(0, |args| args.len()))
else {
let Some(func) = resolver.resolve_function(name.as_str(), args.len()) else {
return false;
};
func.is_deterministic()
&& args
.as_ref()
.is_none_or(|args| args.iter().all(|arg| arg.is_constant(resolver)))
func.is_deterministic() && args.iter().all(|arg| arg.is_constant(resolver))
}
Expr::FunctionCallStar { .. } => false,
Expr::Id(id) => {
@@ -734,10 +722,8 @@ impl Optimizable for ast::Expr {
Expr::Column { .. } => false,
Expr::RowId { .. } => false,
Expr::InList { lhs, rhs, .. } => {
lhs.is_constant(resolver)
&& rhs
.as_ref()
.is_none_or(|rhs| rhs.iter().all(|rhs| rhs.is_constant(resolver)))
lhs.is_constant(resolver) && rhs.is_empty()
|| rhs.iter().all(|v| v.is_constant(resolver))
}
Expr::InSelect { .. } => {
false // might be constant, too annoying to check subqueries etc. implement later
@@ -827,14 +813,6 @@ impl Optimizable for ast::Expr {
Ok(None)
}
Self::InList { lhs: _, not, rhs } => {
if rhs.is_none() {
return Ok(Some(if *not {
AlwaysTrueOrFalse::AlwaysTrue
} else {
AlwaysTrueOrFalse::AlwaysFalse
}));
}
let rhs = rhs.as_ref().unwrap();
if rhs.is_empty() {
return Ok(Some(if *not {
AlwaysTrueOrFalse::AlwaysTrue

View File

@@ -71,19 +71,19 @@ impl OrderTarget {
/// TODO: this does not currently handle the case where we definitely cannot eliminate
/// the ORDER BY sorter, but we could still eliminate the GROUP BY sorter.
pub fn compute_order_target(
order_by_opt: &mut Option<Vec<(ast::Expr, SortOrder)>>,
order_by: &mut Vec<(Box<ast::Expr>, SortOrder)>,
group_by_opt: Option<&mut GroupBy>,
) -> Option<OrderTarget> {
match (&order_by_opt, group_by_opt) {
match (order_by.is_empty(), group_by_opt) {
// No ordering demands - we don't care what order the joined result rows are in
(None, None) => None,
(true, None) => None,
// Only ORDER BY - we would like the joined result rows to be in the order specified by the ORDER BY
(Some(order_by), None) => OrderTarget::maybe_from_iterator(
order_by.iter().map(|(expr, order)| (expr, *order)),
(false, None) => OrderTarget::maybe_from_iterator(
order_by.iter().map(|(expr, order)| (expr.as_ref(), *order)),
EliminatesSortBy::Order,
),
// Only GROUP BY - we would like the joined result rows to be in the order specified by the GROUP BY
(None, Some(group_by)) => OrderTarget::maybe_from_iterator(
(true, Some(group_by)) => OrderTarget::maybe_from_iterator(
group_by.exprs.iter().map(|expr| (expr, SortOrder::Asc)),
EliminatesSortBy::Group,
),
@@ -96,7 +96,7 @@ pub fn compute_order_target(
// If the GROUP BY contains all the expressions in the ORDER BY,
// then we again can use the GROUP BY expressions as the target order for the join;
// however in this case we must take the ASC/DESC from ORDER BY into account.
(Some(order_by), Some(group_by)) => {
(false, Some(group_by)) => {
// Does the group by contain all expressions in the order by?
let group_by_contains_all = order_by.iter().all(|(expr, _)| {
group_by
@@ -133,7 +133,7 @@ pub fn compute_order_target(
*order_by_dir;
}
// Now we can remove the ORDER BY from the query.
order_by_opt.take();
order_by.clear();
OrderTarget::maybe_from_iterator(
group_by

View File

@@ -36,7 +36,7 @@ pub fn init_order_by(
program: &mut ProgramBuilder,
t_ctx: &mut TranslateCtx,
result_columns: &[ResultSetColumn],
order_by: &[(ast::Expr, SortOrder)],
order_by: &[(Box<ast::Expr>, SortOrder)],
referenced_tables: &TableReferences,
) -> Result<()> {
let sort_cursor = program.alloc_cursor_id(CursorType::Sorter);
@@ -55,8 +55,10 @@ pub fn init_order_by(
*/
let collations = order_by
.iter()
.map(|(expr, _)| match expr {
ast::Expr::Collate(_, collation_name) => CollationSeq::new(collation_name).map(Some),
.map(|(expr, _)| match expr.as_ref() {
ast::Expr::Collate(_, collation_name) => {
CollationSeq::new(collation_name.as_str()).map(Some)
}
ast::Expr::Column { table, column, .. } => {
let table = referenced_tables.find_table_by_internal_id(*table).unwrap();
@@ -86,7 +88,7 @@ pub fn emit_order_by(
t_ctx: &mut TranslateCtx,
plan: &SelectPlan,
) -> Result<()> {
let order_by = plan.order_by.as_ref().unwrap();
let order_by = &plan.order_by;
let result_columns = &plan.result_columns;
let sort_loop_start_label = program.allocate_label();
let sort_loop_next_label = program.allocate_label();
@@ -161,7 +163,7 @@ pub fn order_by_sorter_insert(
sort_metadata: &SortMetadata,
plan: &SelectPlan,
) -> Result<()> {
let order_by = plan.order_by.as_ref().unwrap();
let order_by = &plan.order_by;
let order_by_len = order_by.len();
let result_columns = &plan.result_columns;
let result_columns_to_skip_len = sort_metadata
@@ -322,7 +324,7 @@ pub struct OrderByRemapping {
///
/// If any result columns can be skipped, this returns list of 2-tuples of (SkippedResultColumnIndex: usize, ResultColumnIndexInOrderBySorter: usize)
pub fn order_by_deduplicate_result_columns(
order_by: &[(ast::Expr, SortOrder)],
order_by: &[(Box<ast::Expr>, SortOrder)],
result_columns: &[ResultSetColumn],
) -> Vec<OrderByRemapping> {
let mut result_column_remapping: Vec<OrderByRemapping> = Vec::new();

View File

@@ -288,7 +288,7 @@ pub struct SelectPlan {
/// group by clause
pub group_by: Option<GroupBy>,
/// order by clause
pub order_by: Option<Vec<(ast::Expr, SortOrder)>>,
pub order_by: Vec<(Box<ast::Expr>, SortOrder)>,
/// all the aggregates collected from the result columns, order by, and (TODO) having clauses
pub aggregates: Vec<Aggregate>,
/// limit clause
@@ -342,16 +342,22 @@ impl SelectPlan {
return false;
}
let count = turso_parser::ast::Expr::FunctionCall {
name: turso_parser::ast::Name::Ident("count".to_string()),
let count = ast::Expr::FunctionCall {
name: ast::Name::Ident("count".to_string()),
distinctness: None,
args: None,
order_by: None,
filter_over: None,
args: vec![],
order_by: vec![],
filter_over: ast::FunctionTail {
filter_clause: None,
over_clause: None,
},
};
let count_star = turso_parser::ast::Expr::FunctionCallStar {
name: turso_parser::ast::Name::Ident("count".to_string()),
filter_over: None,
let count_star = ast::Expr::FunctionCallStar {
name: ast::Name::Ident("count".to_string()),
filter_over: ast::FunctionTail {
filter_clause: None,
over_clause: None,
},
};
let result_col_expr = &self.result_columns.first().unwrap().expr;
if *result_col_expr != count && *result_col_expr != count_star {
@@ -370,7 +376,7 @@ pub struct DeletePlan {
/// where clause split into a vec at 'AND' boundaries.
pub where_clause: Vec<WhereTerm>,
/// order by clause
pub order_by: Option<Vec<(ast::Expr, SortOrder)>>,
pub order_by: Vec<(Box<ast::Expr>, SortOrder)>,
/// limit clause
pub limit: Option<isize>,
/// offset clause
@@ -385,9 +391,9 @@ pub struct DeletePlan {
pub struct UpdatePlan {
pub table_references: TableReferences,
// (colum index, new value) pairs
pub set_clauses: Vec<(usize, ast::Expr)>,
pub set_clauses: Vec<(usize, Box<ast::Expr>)>,
pub where_clause: Vec<WhereTerm>,
pub order_by: Option<Vec<(ast::Expr, SortOrder)>>,
pub order_by: Vec<(Box<ast::Expr>, SortOrder)>,
pub limit: Option<isize>,
pub offset: Option<isize>,
// TODO: optional RETURNING clause
@@ -410,10 +416,6 @@ pub enum IterationDirection {
pub fn select_star(tables: &[JoinedTable], out_columns: &mut Vec<ResultSetColumn>) {
for table in tables.iter() {
let maybe_using_cols = table
.join_info
.as_ref()
.and_then(|join_info| join_info.using.as_ref());
out_columns.extend(
table
.columns()
@@ -423,8 +425,8 @@ pub fn select_star(tables: &[JoinedTable], out_columns: &mut Vec<ResultSetColumn
.filter(|(_, col)| {
// If we are joining with USING, we need to deduplicate the columns from the right table
// that are also present in the USING clause.
if let Some(using_cols) = maybe_using_cols {
!using_cols.iter().any(|using_col| {
if let Some(join_info) = &table.join_info {
!join_info.using.iter().any(|using_col| {
col.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(using_col.as_str()))

View File

@@ -49,21 +49,17 @@ pub fn resolve_aggregates(
filter_over,
order_by,
} => {
if filter_over.is_some() {
if filter_over.filter_clause.is_some() || filter_over.over_clause.is_some() {
crate::bail_parse_error!(
"FILTER clause is not supported yet in aggregate functions"
);
}
if order_by.is_some() {
if !order_by.is_empty() {
crate::bail_parse_error!(
"ORDER BY clause is not supported yet in aggregate functions"
);
}
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
let args_count = args.len();
match Func::resolve_function(name.as_str(), args_count) {
Ok(Func::Agg(f)) => {
let distinctness = Distinctness::from_ast(distinctness.as_ref());
@@ -72,31 +68,28 @@ pub fn resolve_aggregates(
"SELECT with DISTINCT is not allowed without indexes enabled"
);
}
let num_args = args.as_ref().map_or(0, |args| args.len());
if distinctness.is_distinct() && num_args != 1 {
if distinctness.is_distinct() && args.len() != 1 {
crate::bail_parse_error!(
"DISTINCT aggregate functions must have exactly one argument"
);
}
aggs.push(Aggregate {
func: f,
args: args.clone().unwrap_or_default(),
args: args.iter().map(|arg| *arg.clone()).collect(),
original_expr: expr.clone(),
distinctness,
});
contains_aggregates = true;
}
_ => {
if let Some(args) = args {
for arg in args.iter() {
contains_aggregates |= resolve_aggregates(schema, arg, aggs)?;
}
for arg in args.iter() {
contains_aggregates |= resolve_aggregates(schema, arg, aggs)?;
}
}
}
}
Expr::FunctionCallStar { name, filter_over } => {
if filter_over.is_some() {
if filter_over.filter_clause.is_some() || filter_over.over_clause.is_some() {
crate::bail_parse_error!(
"FILTER clause is not supported yet in aggregate functions"
);
@@ -356,15 +349,15 @@ fn parse_from_clause_table(
ctes,
table_ref_counter,
vtab_predicates,
qualified_name,
maybe_alias,
None,
&qualified_name,
maybe_alias.as_ref(),
&[],
connection,
),
ast::SelectTable::Select(subselect, maybe_alias) => {
let Plan::Select(subplan) = prepare_select_plan(
schema,
*subselect,
subselect,
syms,
table_references.outer_query_refs(),
table_ref_counter,
@@ -392,16 +385,16 @@ fn parse_from_clause_table(
));
Ok(())
}
ast::SelectTable::TableCall(qualified_name, maybe_args, maybe_alias) => parse_table(
ast::SelectTable::TableCall(qualified_name, args, maybe_alias) => parse_table(
schema,
syms,
table_references,
ctes,
table_ref_counter,
vtab_predicates,
qualified_name,
maybe_alias,
maybe_args,
&qualified_name,
maybe_alias.as_ref(),
&args,
connection,
),
_ => todo!(),
@@ -416,14 +409,14 @@ fn parse_table(
ctes: &mut Vec<JoinedTable>,
table_ref_counter: &mut TableRefIdCounter,
vtab_predicates: &mut Vec<Expr>,
qualified_name: QualifiedName,
maybe_alias: Option<As>,
maybe_args: Option<Vec<Expr>>,
qualified_name: &QualifiedName,
maybe_alias: Option<&As>,
args: &[Box<Expr>],
connection: &Arc<crate::Connection>,
) -> Result<()> {
let normalized_qualified_name = normalize_ident(qualified_name.name.as_str());
let database_id = connection.resolve_database_id(&qualified_name)?;
let table_name = qualified_name.name;
let database_id = connection.resolve_database_id(qualified_name)?;
let table_name = qualified_name.name.clone();
// Check if the FROM clause table is referring to a CTE in the current scope.
if let Some(cte_idx) = ctes
@@ -448,14 +441,7 @@ fn parse_table(
.map(|a| a.as_str().to_string());
let internal_id = table_ref_counter.next();
let tbl_ref = if let Table::Virtual(tbl) = table.as_ref() {
if let Some(args) = maybe_args {
transform_args_into_where_terms(
args,
internal_id,
vtab_predicates,
table.as_ref(),
)?;
}
transform_args_into_where_terms(args, internal_id, vtab_predicates, table.as_ref())?;
Table::Virtual(tbl.clone())
} else if let Table::BTree(table) = table.as_ref() {
Table::BTree(table.clone())
@@ -485,12 +471,14 @@ fn parse_table(
let subselect = Box::new(view_select);
// Use the view name as alias if no explicit alias was provided
let view_alias = maybe_alias.or_else(|| Some(ast::As::As(table_name.clone())));
let view_alias = maybe_alias
.cloned()
.or_else(|| Some(ast::As::As(table_name.clone())));
// Recursively call parse_from_clause_table with the view as a SELECT
return parse_from_clause_table(
schema,
ast::SelectTable::Select(subselect, view_alias),
ast::SelectTable::Select(*subselect.clone(), view_alias),
table_references,
vtab_predicates,
ctes,
@@ -559,12 +547,12 @@ fn parse_table(
}
fn transform_args_into_where_terms(
args: Vec<Expr>,
args: &[Box<Expr>],
internal_id: TableInternalId,
predicates: &mut Vec<Expr>,
table: &Table,
) -> Result<()> {
let mut args_iter = args.into_iter();
let mut args_iter = args.iter();
let mut hidden_count = 0;
for (i, col) in table.columns().iter().enumerate() {
if !col.hidden {
@@ -579,12 +567,12 @@ fn transform_args_into_where_terms(
column: i,
is_rowid_alias: col.is_rowid_alias,
};
let expr = match arg_expr {
let expr = match arg_expr.as_ref() {
Expr::Literal(Null) => Expr::IsNull(Box::new(column_expr)),
other => Expr::Binary(
Box::new(column_expr),
column_expr.into(),
ast::Operator::Equals,
Box::new(other),
other.clone().into(),
),
};
predicates.push(expr);
@@ -615,7 +603,7 @@ pub fn parse_from(
table_ref_counter: &mut TableRefIdCounter,
connection: &Arc<crate::Connection>,
) -> Result<()> {
if from.as_ref().and_then(|f| f.select.as_ref()).is_none() {
if from.is_none() {
return Ok(());
}
@@ -629,7 +617,7 @@ pub fn parse_from(
if cte.materialized == Materialized::Yes {
crate::bail_parse_error!("Materialized CTEs are not yet supported");
}
if cte.columns.is_some() {
if !cte.columns.is_empty() {
crate::bail_parse_error!("CTE columns are not yet supported");
}
@@ -668,7 +656,7 @@ pub fn parse_from(
// CTE can refer to other CTEs that came before it, plus any schema tables or tables in the outer scope.
let cte_plan = prepare_select_plan(
schema,
*cte.select,
cte.select,
syms,
&outer_query_refs_for_cte,
table_ref_counter,
@@ -690,12 +678,12 @@ pub fn parse_from(
}
}
let mut from_owned = std::mem::take(&mut from).unwrap();
let select_owned = *std::mem::take(&mut from_owned.select).unwrap();
let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default();
let from_owned = std::mem::take(&mut from).unwrap();
let select_owned = from_owned.select;
let joins_owned = from_owned.joins;
parse_from_clause_table(
schema,
select_owned,
*select_owned,
table_references,
vtab_predicates,
&mut ctes_as_subqueries,
@@ -722,7 +710,7 @@ pub fn parse_from(
}
pub fn parse_where(
where_clause: Option<Expr>,
where_clause: Option<&Expr>,
table_references: &mut TableReferences,
result_columns: Option<&[ResultSetColumn]>,
out_where_clause: &mut Vec<WhereTerm>,
@@ -941,7 +929,7 @@ fn parse_join(
parse_from_clause_table(
schema,
table,
table.as_ref().clone(),
table_references,
vtab_predicates,
ctes,
@@ -959,8 +947,6 @@ fn parse_join(
_ => (false, false),
};
let mut using = None;
if natural && constraint.is_some() {
crate::bail_parse_error!("NATURAL JOIN cannot be combined with ON or USING clause");
}
@@ -969,7 +955,7 @@ fn parse_join(
assert!(table_references.joined_tables().len() >= 2);
let rightmost_table = table_references.joined_tables().last().unwrap();
// NATURAL JOIN is first transformed into a USING join with the common columns
let mut distinct_names: Option<ast::DistinctNames> = None;
let mut distinct_names: Vec<ast::Name> = vec![];
// TODO: O(n^2) maybe not great for large tables or big multiway joins
// SQLite doesn't use HIDDEN columns for NATURAL joins: https://www3.sqlite.org/src/info/ab09ef427181130b
for right_col in rightmost_table.columns().iter().filter(|col| !col.hidden) {
@@ -981,17 +967,9 @@ fn parse_join(
{
for left_col in left_table.columns().iter().filter(|col| !col.hidden) {
if left_col.name == right_col.name {
if let Some(distinct_names) = distinct_names.as_mut() {
distinct_names
.insert(ast::Name::from_str(
&left_col.name.clone().expect("column name is None"),
))
.unwrap();
} else {
distinct_names = Some(ast::DistinctNames::new(ast::Name::from_str(
&left_col.name.clone().expect("column name is None"),
)));
}
distinct_names.push(ast::Name::new(
left_col.name.clone().expect("column name is None"),
));
found_match = true;
break;
}
@@ -1001,18 +979,20 @@ fn parse_join(
}
}
}
if let Some(distinct_names) = distinct_names {
Some(ast::JoinConstraint::Using(distinct_names))
} else {
if distinct_names.is_empty() {
crate::bail_parse_error!("No columns found to NATURAL join on");
} else {
Some(ast::JoinConstraint::Using(distinct_names))
}
} else {
constraint
};
let mut using = vec![];
if let Some(constraint) = constraint {
match constraint {
ast::JoinConstraint::On(expr) => {
ast::JoinConstraint::On(ref expr) => {
let mut preds = vec![];
break_predicate_at_and_boundaries(expr, &mut preds);
for predicate in preds.iter_mut() {
@@ -1110,7 +1090,7 @@ fn parse_join(
consumed: false,
});
}
using = Some(distinct_names);
using = distinct_names;
}
}
}
@@ -1128,7 +1108,7 @@ fn parse_join(
pub fn parse_limit(limit: &Limit) -> Result<(Option<isize>, Option<isize>)> {
let offset_val = match &limit.offset {
Some(offset_expr) => match offset_expr {
Some(offset_expr) => match offset_expr.as_ref() {
Expr::Literal(ast::Literal::Numeric(n)) => n.parse().ok(),
// If OFFSET is negative, the result is as if OFFSET is zero
Expr::Unary(UnaryOperator::Negative, expr) => {
@@ -1143,16 +1123,16 @@ pub fn parse_limit(limit: &Limit) -> Result<(Option<isize>, Option<isize>)> {
None => Some(0),
};
if let Expr::Literal(ast::Literal::Numeric(n)) = &limit.expr {
if let Expr::Literal(ast::Literal::Numeric(n)) = limit.expr.as_ref() {
Ok((n.parse().ok(), offset_val))
} else if let Expr::Unary(UnaryOperator::Negative, expr) = &limit.expr {
if let Expr::Literal(ast::Literal::Numeric(n)) = &**expr {
} else if let Expr::Unary(UnaryOperator::Negative, expr) = limit.expr.as_ref() {
if let Expr::Literal(ast::Literal::Numeric(n)) = expr.as_ref() {
let limit_val = n.parse::<isize>().ok().map(|num| -num);
Ok((limit_val, offset_val))
} else {
crate::bail_parse_error!("Invalid LIMIT clause");
}
} else if let Expr::Id(id) = &limit.expr {
} else if let Expr::Id(id) = limit.expr.as_ref() {
if id.as_str().eq_ignore_ascii_case("true") {
Ok((Some(1), offset_val))
} else if id.as_str().eq_ignore_ascii_case("false") {
@@ -1165,14 +1145,14 @@ pub fn parse_limit(limit: &Limit) -> Result<(Option<isize>, Option<isize>)> {
}
}
pub fn break_predicate_at_and_boundaries(predicate: Expr, out_predicates: &mut Vec<Expr>) {
pub fn break_predicate_at_and_boundaries(predicate: &Expr, out_predicates: &mut Vec<Expr>) {
match predicate {
Expr::Binary(left, ast::Operator::And, right) => {
break_predicate_at_and_boundaries(*left, out_predicates);
break_predicate_at_and_boundaries(*right, out_predicates);
break_predicate_at_and_boundaries(left, out_predicates);
break_predicate_at_and_boundaries(right, out_predicates);
}
_ => {
out_predicates.push(predicate);
out_predicates.push(predicate.clone());
}
}
}

View File

@@ -63,9 +63,9 @@ pub fn translate_pragma(
None => query_pragma(pragma, schema, None, pager, connection, program)?,
Some(ast::PragmaBody::Equals(value) | ast::PragmaBody::Call(value)) => match pragma {
PragmaName::TableInfo => {
query_pragma(pragma, schema, Some(value), pager, connection, program)?
query_pragma(pragma, schema, Some(*value), pager, connection, program)?
}
_ => update_pragma(pragma, schema, syms, value, pager, connection, program)?,
_ => update_pragma(pragma, schema, syms, *value, pager, connection, program)?,
},
};
match mode {
@@ -275,14 +275,17 @@ fn update_pragma(
if let Some(table) = &opts.table() {
// make sure that we have table created
program = translate_create_table(
QualifiedName::single(ast::Name::from_str(table)),
QualifiedName {
db_name: None,
name: ast::Name::new(table),
alias: None,
},
false,
ast::CreateTableBody::columns_and_constraints_from_definition(
turso_cdc_table_columns(),
None,
ast::TableOptions::NONE,
)
.unwrap(),
ast::CreateTableBody::ColumnsAndConstraints {
columns: turso_cdc_table_columns(),
constraints: vec![],
options: ast::TableOptions::NONE,
},
true,
schema,
syms,
@@ -460,9 +463,7 @@ fn query_pragma(
let view = view_mutex.lock().unwrap();
emit_columns_for_table_info(&mut program, &view.columns, base_reg);
} else if let Some(view) = schema.get_view(&name) {
if let Some(ref columns) = view.columns {
emit_columns_for_table_info(&mut program, columns, base_reg);
}
emit_columns_for_table_info(&mut program, &view.columns, base_reg);
}
}
let col_names = ["cid", "name", "type", "notnull", "dflt_value", "pk"];
@@ -698,7 +699,7 @@ pub const TURSO_CDC_DEFAULT_TABLE_NAME: &str = "turso_cdc";
fn turso_cdc_table_columns() -> Vec<ColumnDefinition> {
vec![
ast::ColumnDefinition {
col_name: ast::Name::from_str("change_id"),
col_name: ast::Name::new("change_id"),
col_type: Some(ast::Type {
name: "INTEGER".to_string(),
size: None,
@@ -713,7 +714,7 @@ fn turso_cdc_table_columns() -> Vec<ColumnDefinition> {
}],
},
ast::ColumnDefinition {
col_name: ast::Name::from_str("change_time"),
col_name: ast::Name::new("change_time"),
col_type: Some(ast::Type {
name: "INTEGER".to_string(),
size: None,
@@ -721,7 +722,7 @@ fn turso_cdc_table_columns() -> Vec<ColumnDefinition> {
constraints: vec![],
},
ast::ColumnDefinition {
col_name: ast::Name::from_str("change_type"),
col_name: ast::Name::new("change_type"),
col_type: Some(ast::Type {
name: "INTEGER".to_string(),
size: None,
@@ -729,7 +730,7 @@ fn turso_cdc_table_columns() -> Vec<ColumnDefinition> {
constraints: vec![],
},
ast::ColumnDefinition {
col_name: ast::Name::from_str("table_name"),
col_name: ast::Name::new("table_name"),
col_type: Some(ast::Type {
name: "TEXT".to_string(),
size: None,
@@ -737,12 +738,12 @@ fn turso_cdc_table_columns() -> Vec<ColumnDefinition> {
constraints: vec![],
},
ast::ColumnDefinition {
col_name: ast::Name::from_str("id"),
col_name: ast::Name::new("id"),
col_type: None,
constraints: vec![],
},
ast::ColumnDefinition {
col_name: ast::Name::from_str("before"),
col_name: ast::Name::new("before"),
col_type: Some(ast::Type {
name: "BLOB".to_string(),
size: None,
@@ -750,7 +751,7 @@ fn turso_cdc_table_columns() -> Vec<ColumnDefinition> {
constraints: vec![],
},
ast::ColumnDefinition {
col_name: ast::Name::from_str("after"),
col_name: ast::Name::new("after"),
col_type: Some(ast::Type {
name: "BLOB".to_string(),
size: None,
@@ -758,7 +759,7 @@ fn turso_cdc_table_columns() -> Vec<ColumnDefinition> {
constraints: vec![],
},
ast::ColumnDefinition {
col_name: ast::Name::from_str("updates"),
col_name: ast::Name::new("updates"),
col_type: Some(ast::Type {
name: "BLOB".to_string(),
size: None,

View File

@@ -308,100 +308,110 @@ fn check_automatic_pk_index_required(
let mut unique_sets = vec![];
// Check table constraints for PRIMARY KEY
if let Some(constraints) = constraints {
for constraint in constraints {
if let ast::TableConstraint::PrimaryKey {
columns: pk_cols, ..
} = &constraint.constraint
{
if primary_key_definition.is_some() {
bail_parse_error!("table {} has more than one primary key", tbl_name);
}
let primary_key_column_results = pk_cols
for constraint in constraints {
if let ast::TableConstraint::PrimaryKey {
columns: pk_cols, ..
} = &constraint.constraint
{
if primary_key_definition.is_some() {
bail_parse_error!("table {} has more than one primary key", tbl_name);
}
let primary_key_column_results = pk_cols
.iter()
.map(|col| match col.expr.as_ref() {
ast::Expr::Id(name) => {
if !columns.iter().any(
|ast::ColumnDefinition { col_name, .. }| {
col_name.as_str() == name.as_str()
},
) {
bail_parse_error!("No such column: {}", name.as_str());
}
Ok(PrimaryKeyColumnInfo {
name: name.as_str(),
is_descending: matches!(col.order, Some(ast::SortOrder::Desc)),
})
}
_ => Err(LimboError::ParseError(
"expressions prohibited in PRIMARY KEY and UNIQUE constraints"
.to_string(),
)),
})
.collect::<Result<Vec<_>>>()?;
for pk_info in primary_key_column_results {
let column_name = pk_info.name;
let column_def = columns
.iter()
.map(|col| match &col.expr {
ast::Expr::Id(name) => {
if !columns.iter().any(|(k, _)| k.as_str() == name.as_str()) {
bail_parse_error!("No such column: {}", name.as_str());
}
Ok(PrimaryKeyColumnInfo {
name: name.as_str(),
is_descending: matches!(
col.order,
Some(ast::SortOrder::Desc)
),
})
}
_ => Err(LimboError::ParseError(
"expressions prohibited in PRIMARY KEY and UNIQUE constraints"
.to_string(),
)),
.find(|ast::ColumnDefinition { col_name, .. }| {
col_name.as_str() == column_name
})
.collect::<Result<Vec<_>>>()?;
.expect("primary key column should be in Create Body columns");
for pk_info in primary_key_column_results {
let column_name = pk_info.name;
let (_, column_def) = columns
.iter()
.find(|(k, _)| k.as_str() == column_name)
.expect("primary key column should be in Create Body columns");
match &mut primary_key_definition {
Some(PrimaryKeyDefinitionType::Simple { column, .. }) => {
let mut columns = HashSet::new();
columns.insert(std::mem::take(column));
// Have to also insert the current column_name we are iterating over in primary_key_column_results
columns.insert(column_name.to_string());
primary_key_definition =
Some(PrimaryKeyDefinitionType::Composite { columns });
}
Some(PrimaryKeyDefinitionType::Composite { columns }) => {
columns.insert(column_name.to_string());
}
None => {
let typename =
column_def.col_type.as_ref().map(|t| t.name.as_str());
let is_descending = pk_info.is_descending;
primary_key_definition =
Some(PrimaryKeyDefinitionType::Simple {
typename,
is_descending,
column: column_name.to_string(),
});
}
match &mut primary_key_definition {
Some(PrimaryKeyDefinitionType::Simple { column, .. }) => {
let mut columns = HashSet::new();
columns.insert(std::mem::take(column));
// Have to also insert the current column_name we are iterating over in primary_key_column_results
columns.insert(column_name.to_string());
primary_key_definition =
Some(PrimaryKeyDefinitionType::Composite { columns });
}
Some(PrimaryKeyDefinitionType::Composite { columns }) => {
columns.insert(column_name.to_string());
}
None => {
let typename =
column_def.col_type.as_ref().map(|t| t.name.as_str());
let is_descending = pk_info.is_descending;
primary_key_definition = Some(PrimaryKeyDefinitionType::Simple {
typename,
is_descending,
column: column_name.to_string(),
});
}
}
} else if let ast::TableConstraint::Unique {
columns: unique_columns,
conflict_clause,
} = &constraint.constraint
{
if conflict_clause.is_some() {
unimplemented!("ON CONFLICT not implemented");
}
let col_names = unique_columns
.iter()
.map(|column| match &column.expr {
turso_parser::ast::Expr::Id(id) => {
if !columns.iter().any(|(k, _)| k.as_str() == id.as_str()) {
bail_parse_error!("No such column: {}", id.as_str());
}
Ok(crate::util::normalize_ident(id.as_str()))
}
_ => {
todo!("Unsupported unique expression");
}
})
.collect::<Result<HashSet<String>>>()?;
unique_sets.push(col_names);
}
} else if let ast::TableConstraint::Unique {
columns: unique_columns,
conflict_clause,
} = &constraint.constraint
{
if conflict_clause.is_some() {
unimplemented!("ON CONFLICT not implemented");
}
let col_names = unique_columns
.iter()
.map(|column| match column.expr.as_ref() {
turso_parser::ast::Expr::Id(id) => {
if !columns.iter().any(
|ast::ColumnDefinition { col_name, .. }| {
col_name.as_str() == id.as_str()
},
) {
bail_parse_error!("No such column: {}", id.as_str());
}
Ok(crate::util::normalize_ident(id.as_str()))
}
_ => {
todo!("Unsupported unique expression");
}
})
.collect::<Result<HashSet<String>>>()?;
unique_sets.push(col_names);
}
}
// Check column constraints for PRIMARY KEY and UNIQUE
for (_, col_def) in columns.iter() {
for constraint in &col_def.constraints {
for ast::ColumnDefinition {
col_name,
col_type,
constraints,
..
} in columns.iter()
{
for constraint in constraints {
if matches!(
constraint.constraint,
ast::ColumnConstraint::PrimaryKey { .. }
@@ -409,15 +419,15 @@ fn check_automatic_pk_index_required(
if primary_key_definition.is_some() {
bail_parse_error!("table {} has more than one primary key", tbl_name);
}
let typename = col_def.col_type.as_ref().map(|t| t.name.as_str());
let typename = col_type.as_ref().map(|t| t.name.as_str());
primary_key_definition = Some(PrimaryKeyDefinitionType::Simple {
typename,
is_descending: false,
column: col_def.col_name.as_str().to_string(),
column: col_name.as_str().to_string(),
});
} else if matches!(constraint.constraint, ast::ColumnConstraint::Unique(..)) {
let mut single_set = HashSet::new();
single_set.insert(col_def.col_name.as_str().to_string());
single_set.insert(col_name.as_str().to_string());
unique_sets.push(single_set);
}
}
@@ -506,15 +516,13 @@ fn create_table_body_to_str(tbl_name: &ast::QualifiedName, body: &ast::CreateTab
sql
}
fn create_vtable_body_to_str(vtab: &CreateVirtualTable, module: Rc<VTabImpl>) -> String {
let args = if let Some(args) = &vtab.args {
args.iter()
.map(|arg| arg.to_string())
.collect::<Vec<String>>()
.join(", ")
} else {
"".to_string()
};
fn create_vtable_body_to_str(vtab: &ast::CreateVirtualTable, module: Rc<VTabImpl>) -> String {
let args = vtab
.args
.iter()
.map(|arg| arg.to_string())
.collect::<Vec<String>>()
.join(", ");
let if_not_exists = if vtab.if_not_exists {
"IF NOT EXISTS "
} else {
@@ -522,8 +530,6 @@ fn create_vtable_body_to_str(vtab: &CreateVirtualTable, module: Rc<VTabImpl>) ->
};
let ext_args = vtab
.args
.as_ref()
.unwrap_or(&vec![])
.iter()
.map(|a| turso_ext::Value::from_text(a.to_string()))
.collect::<Vec<_>>();
@@ -553,7 +559,7 @@ fn create_vtable_body_to_str(vtab: &CreateVirtualTable, module: Rc<VTabImpl>) ->
}
pub fn translate_create_virtual_table(
vtab: CreateVirtualTable,
vtab: ast::CreateVirtualTable,
schema: &Schema,
syms: &SymbolTable,
mut program: ProgramBuilder,
@@ -567,7 +573,7 @@ pub fn translate_create_virtual_table(
let table_name = tbl_name.name.as_str().to_string();
let module_name_str = module_name.as_str().to_string();
let args_vec = args.clone().unwrap_or_default();
let args_vec = args.clone();
let Some(vtab_module) = syms.vtab_modules.get(&module_name_str) else {
bail_parse_error!("no such module: {}", module_name_str);
};

View File

@@ -17,8 +17,8 @@ use crate::vdbe::insn::Insn;
use crate::{schema::Schema, vdbe::builder::ProgramBuilder, Result};
use crate::{Connection, SymbolTable};
use std::sync::Arc;
use turso_parser::ast::{self, CompoundSelect, Expr, SortOrder};
use turso_parser::ast::ResultColumn;
use turso_parser::ast::{self, CompoundSelect, Expr, SortOrder};
pub struct TranslateSelectResult {
pub program: ProgramBuilder,
@@ -90,36 +90,33 @@ pub fn translate_select(
pub fn prepare_select_plan(
schema: &Schema,
mut select: ast::Select,
select: ast::Select,
syms: &SymbolTable,
outer_query_refs: &[OuterQueryReference],
table_ref_counter: &mut TableRefIdCounter,
query_destination: QueryDestination,
connection: &Arc<crate::Connection>,
) -> Result<Plan> {
let compounds = select.body.compounds.take();
match compounds {
None => {
let limit = select.limit.take();
Ok(Plan::Select(prepare_one_select_plan(
schema,
*select.body.select,
limit.as_deref(),
select.order_by.take(),
select.with.take(),
syms,
outer_query_refs,
table_ref_counter,
query_destination,
connection,
)?))
}
Some(compounds) => {
let compounds = select.body.compounds;
match compounds.is_empty() {
true => Ok(Plan::Select(prepare_one_select_plan(
schema,
select.body.select,
select.limit,
select.order_by,
select.with,
syms,
outer_query_refs,
table_ref_counter,
query_destination,
connection,
)?)),
false => {
let mut last = prepare_one_select_plan(
schema,
*select.body.select,
None,
select.body.select,
None,
vec![],
None,
syms,
outer_query_refs,
@@ -133,9 +130,9 @@ pub fn prepare_select_plan(
left.push((last, operator));
last = prepare_one_select_plan(
schema,
*select,
None,
select,
None,
vec![],
None,
syms,
outer_query_refs,
@@ -152,10 +149,13 @@ pub fn prepare_select_plan(
crate::bail_parse_error!("SELECTs to the left and right of {} do not have the same number of result columns", operator);
}
}
let (limit, offset) = select.limit.map_or(Ok((None, None)), |l| parse_limit(&l))?;
let (limit, offset) = select
.limit
.as_ref()
.map_or(Ok((None, None)), parse_limit)?;
// FIXME: handle ORDER BY for compound selects
if select.order_by.is_some() {
if !select.order_by.is_empty() {
crate::bail_parse_error!("ORDER BY is not supported for compound SELECTs yet");
}
// FIXME: handle WITH for compound selects
@@ -177,8 +177,8 @@ pub fn prepare_select_plan(
fn prepare_one_select_plan(
schema: &Schema,
select: ast::OneSelect,
limit: Option<&ast::Limit>,
order_by: Option<Vec<ast::SortedColumn>>,
limit: Option<ast::Limit>,
order_by: Vec<ast::SortedColumn>,
with: Option<ast::With>,
syms: &SymbolTable,
outer_query_refs: &[OuterQueryReference],
@@ -187,22 +187,21 @@ fn prepare_one_select_plan(
connection: &Arc<crate::Connection>,
) -> Result<SelectPlan> {
match select {
ast::OneSelect::Select(select_inner) => {
let SelectInner {
mut columns,
from,
where_clause,
group_by,
distinctness,
window_clause,
..
} = *select_inner;
ast::OneSelect::Select {
mut columns,
from,
where_clause,
group_by,
distinctness,
window_clause,
..
} => {
if !schema.indexes_enabled() && distinctness.is_some() {
crate::bail_parse_error!(
"SELECT with DISTINCT is not allowed without indexes enabled"
);
}
if window_clause.is_some() {
if !window_clause.is_empty() {
crate::bail_parse_error!("WINDOW clause is not supported yet");
}
let col_count = columns.len();
@@ -275,7 +274,7 @@ fn prepare_one_select_plan(
result_columns,
where_clause: where_predicates,
group_by: None,
order_by: None,
order_by: vec![],
aggregates: vec![],
limit: None,
offset: None,
@@ -341,7 +340,7 @@ fn prepare_one_select_plan(
Some(&plan.result_columns),
connection,
)?;
match expr {
match expr.as_ref() {
ast::Expr::FunctionCall {
name,
distinctness,
@@ -349,19 +348,17 @@ fn prepare_one_select_plan(
filter_over,
order_by,
} => {
if filter_over.is_some() {
if filter_over.filter_clause.is_some()
|| filter_over.over_clause.is_some()
{
crate::bail_parse_error!(
"FILTER clause is not supported yet in aggregate functions"
);
}
if order_by.is_some() {
if !order_by.is_empty() {
crate::bail_parse_error!("ORDER BY clause is not supported yet in aggregate functions");
}
let args_count = if let Some(args) = &args {
args.len()
} else {
0
};
let args_count = args.len();
let distinctness = Distinctness::from_ast(distinctness.as_ref());
if !schema.indexes_enabled() && distinctness.is_distinct() {
@@ -374,24 +371,25 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), args_count) {
Ok(Func::Agg(f)) => {
let agg_args = match (args, &f) {
(None, crate::function::AggFunc::Count0) => {
let agg_args = match (args.is_empty(), &f) {
(true, crate::function::AggFunc::Count0) => {
// COUNT() case
vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))]
))
.into()]
}
(None, _) => crate::bail_parse_error!(
(true, _) => crate::bail_parse_error!(
"Aggregate function {} requires arguments",
name.as_str()
),
(Some(args), _) => args.clone(),
(false, _) => args.clone(),
};
let agg = Aggregate {
func: f,
args: agg_args.clone(),
original_expr: expr.clone(),
args: agg_args.iter().map(|arg| *arg.clone()).collect(),
original_expr: *expr.clone(),
distinctness,
};
aggregate_expressions.push(agg.clone());
@@ -402,7 +400,7 @@ fn prepare_one_select_plan(
}
ast::As::As(alias) => alias.as_str().to_string(),
}),
expr: expr.clone(),
expr: *expr.clone(),
contains_aggregates: true,
});
}
@@ -419,7 +417,7 @@ fn prepare_one_select_plan(
}
ast::As::As(alias) => alias.as_str().to_string(),
}),
expr: expr.clone(),
expr: *expr.clone(),
contains_aggregates,
});
}
@@ -444,14 +442,17 @@ fn prepare_one_select_plan(
}
}
}),
expr: expr.clone(),
expr: *expr.clone(),
contains_aggregates,
});
} else {
let agg = Aggregate {
func: AggFunc::External(f.func.clone().into()),
args: args.as_ref().unwrap().clone(),
original_expr: expr.clone(),
args: args
.iter()
.map(|arg| *arg.clone())
.collect(),
original_expr: *expr.clone(),
distinctness,
};
aggregate_expressions.push(agg.clone());
@@ -466,7 +467,7 @@ fn prepare_one_select_plan(
}
}
}),
expr: expr.clone(),
expr: *expr.clone(),
contains_aggregates: true,
});
}
@@ -478,7 +479,9 @@ fn prepare_one_select_plan(
}
}
ast::Expr::FunctionCallStar { name, filter_over } => {
if filter_over.is_some() {
if filter_over.filter_clause.is_some()
|| filter_over.over_clause.is_some()
{
crate::bail_parse_error!(
"FILTER clause is not supported yet in aggregate functions"
);
@@ -490,7 +493,7 @@ fn prepare_one_select_plan(
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
original_expr: expr.clone(),
original_expr: *expr.clone(),
distinctness: Distinctness::NonDistinct,
};
aggregate_expressions.push(agg.clone());
@@ -501,7 +504,7 @@ fn prepare_one_select_plan(
}
ast::As::As(alias) => alias.as_str().to_string(),
}),
expr: expr.clone(),
expr: *expr.clone(),
contains_aggregates: true,
});
}
@@ -548,7 +551,7 @@ fn prepare_one_select_plan(
// Parse the actual WHERE clause and add its conditions to the plan WHERE clause that already contains the join conditions.
parse_where(
where_clause,
where_clause.as_deref(),
&mut plan.table_references,
Some(&plan.result_columns),
&mut plan.where_clause,
@@ -568,10 +571,10 @@ fn prepare_one_select_plan(
plan.group_by = Some(GroupBy {
sort_order: Some((0..group_by.exprs.len()).map(|_| SortOrder::Asc).collect()),
exprs: group_by.exprs,
exprs: group_by.exprs.iter().map(|expr| *expr.clone()).collect(),
having: if let Some(having) = group_by.having {
let mut predicates = vec![];
break_predicate_at_and_boundaries(*having, &mut predicates);
break_predicate_at_and_boundaries(&having, &mut predicates);
for expr in predicates.iter_mut() {
bind_column_references(
expr,
@@ -601,30 +604,25 @@ fn prepare_one_select_plan(
plan.aggregates = aggregate_expressions;
// Parse the ORDER BY clause
if let Some(order_by) = order_by {
let mut key = Vec::new();
let mut key = Vec::new();
for mut o in order_by {
replace_column_number_with_copy_of_column_expr(
&mut o.expr,
&plan.result_columns,
)?;
for mut o in order_by {
replace_column_number_with_copy_of_column_expr(&mut o.expr, &plan.result_columns)?;
bind_column_references(
&mut o.expr,
&mut plan.table_references,
Some(&plan.result_columns),
connection,
)?;
resolve_aggregates(schema, &o.expr, &mut plan.aggregates)?;
bind_column_references(
&mut o.expr,
&mut plan.table_references,
Some(&plan.result_columns),
connection,
)?;
resolve_aggregates(schema, &o.expr, &mut plan.aggregates)?;
key.push((o.expr, o.order.unwrap_or(ast::SortOrder::Asc)));
}
plan.order_by = Some(key);
key.push((o.expr, o.order.unwrap_or(ast::SortOrder::Asc)));
}
plan.order_by = key;
// Parse the LIMIT/OFFSET clause
(plan.limit, plan.offset) = limit.map_or(Ok((None, None)), parse_limit)?;
(plan.limit, plan.offset) = limit.as_ref().map_or(Ok((None, None)), parse_limit)?;
// Return the unoptimized query plan
Ok(plan)
@@ -646,14 +644,17 @@ fn prepare_one_select_plan(
result_columns,
where_clause: vec![],
group_by: None,
order_by: None,
order_by: vec![],
aggregates: vec![],
limit: None,
offset: None,
contains_constant_false_condition: false,
query_destination,
distinctness: Distinctness::NonDistinct,
values,
values: values
.iter()
.map(|values| values.iter().map(|value| *value.clone()).collect())
.collect(),
};
Ok(plan)
@@ -725,8 +726,8 @@ fn count_plan_required_cursors(plan: &SelectPlan) -> usize {
0
})
.sum();
let num_sorter_cursors = plan.group_by.is_some() as usize + plan.order_by.is_some() as usize;
let num_pseudo_cursors = plan.group_by.is_some() as usize + plan.order_by.is_some() as usize;
let num_sorter_cursors = plan.group_by.is_some() as usize + !plan.order_by.is_empty() as usize;
let num_pseudo_cursors = plan.group_by.is_some() as usize + !plan.order_by.is_empty() as usize;
num_table_cursors + num_sorter_cursors + num_pseudo_cursors
}
@@ -746,7 +747,7 @@ fn estimate_num_instructions(select: &SelectPlan) -> usize {
.sum();
let group_by_instructions = select.group_by.is_some() as usize * 10;
let order_by_instructions = select.order_by.is_some() as usize * 10;
let order_by_instructions = !select.order_by.is_empty() as usize * 10;
let condition_instructions = select.where_clause.len() * 3;
20 + table_instructions + group_by_instructions + order_by_instructions + condition_instructions
@@ -770,7 +771,7 @@ fn estimate_num_labels(select: &SelectPlan) -> usize {
+ 1;
let group_by_labels = select.group_by.is_some() as usize * 10;
let order_by_labels = select.order_by.is_some() as usize * 10;
let order_by_labels = !select.order_by.is_empty() as usize * 10;
let condition_labels = select.where_clause.len() * 2;
init_halt_labels + table_labels + group_by_labels + order_by_labels + condition_labels

View File

@@ -12,7 +12,7 @@ use crate::{
vdbe::builder::{ProgramBuilder, ProgramBuilderOpts},
SymbolTable,
};
use turso_parser::ast::{Expr, Indexed, SortOrder};
use turso_parser::ast::{self, Expr, Indexed, SortOrder};
use super::emitter::emit_program;
use super::expr::process_returning_clause;
@@ -54,7 +54,7 @@ addr opcode p1 p2 p3 p4 p5 comment
*/
pub fn translate_update(
schema: &Schema,
body: &mut Update,
body: &mut ast::Update,
syms: &SymbolTable,
mut program: ProgramBuilder,
connection: &Arc<crate::Connection>,
@@ -74,7 +74,7 @@ pub fn translate_update(
pub fn translate_update_for_schema_change(
schema: &Schema,
body: &mut Update,
body: &mut ast::Update,
syms: &SymbolTable,
mut program: ProgramBuilder,
connection: &Arc<crate::Connection>,
@@ -104,7 +104,7 @@ pub fn translate_update_for_schema_change(
pub fn prepare_update_plan(
program: &mut ProgramBuilder,
schema: &Schema,
body: &mut Update,
body: &mut ast::Update,
connection: &Arc<crate::Connection>,
) -> crate::Result<Plan> {
if body.with.is_some() {
@@ -134,13 +134,11 @@ pub fn prepare_update_plan(
};
let iter_dir = body
.order_by
.as_ref()
.and_then(|order_by| {
order_by.first().and_then(|ob| {
ob.order.map(|o| match o {
SortOrder::Asc => IterationDirection::Forwards,
SortOrder::Desc => IterationDirection::Backwards,
})
.first()
.and_then(|ob| {
ob.order.map(|o| match o {
SortOrder::Asc => IterationDirection::Forwards,
SortOrder::Desc => IterationDirection::Backwards,
})
})
.unwrap_or(IterationDirection::Forwards);
@@ -174,9 +172,9 @@ pub fn prepare_update_plan(
for set in &mut body.sets {
bind_column_references(&mut set.expr, &mut table_references, None, connection)?;
let values = match &set.expr {
let values = match set.expr.as_ref() {
Expr::Parenthesized(vals) => vals.clone(),
expr => vec![expr.clone()],
expr => vec![expr.clone().into()],
};
if set.col_names.len() != values.len() {
@@ -203,27 +201,19 @@ pub fn prepare_update_plan(
}
}
let (result_columns, _table_references) = if let Some(returning) = &mut body.returning {
process_returning_clause(
returning,
&table,
body.tbl_name.name.as_str(),
program,
connection,
)?
} else {
(
vec![],
crate::translate::plan::TableReferences::new(vec![], vec![]),
)
};
let (result_columns, _table_references) = process_returning_clause(
&mut body.returning,
&table,
body.tbl_name.name.as_str(),
program,
connection,
)?;
let order_by = body.order_by.as_ref().map(|order| {
order
.iter()
.map(|o| (o.expr.clone(), o.order.unwrap_or(SortOrder::Asc)))
.collect()
});
let order_by = body
.order_by
.iter()
.map(|o| (o.expr.clone(), o.order.unwrap_or(SortOrder::Asc)))
.collect();
// Sqlite determines we should create an ephemeral table if we do not have a FROM clause
// Difficult to say what items from the plan can be checked for this so currently just checking if a RowId Alias is referenced
@@ -256,7 +246,7 @@ pub fn prepare_update_plan(
// Parse the WHERE clause
parse_where(
body.where_clause.as_ref().map(|w| *w.clone()),
body.where_clause.as_deref(),
&mut table_references,
Some(&result_columns),
&mut where_clause,
@@ -298,7 +288,7 @@ pub fn prepare_update_plan(
}],
where_clause, // original WHERE terms from the UPDATE clause
group_by: None, // N/A
order_by: None, // N/A
order_by: vec![], // N/A
aggregates: vec![], // N/A
limit: None, // N/A
query_destination: QueryDestination::EphemeralTable {
@@ -331,7 +321,7 @@ pub fn prepare_update_plan(
if ephemeral_plan.is_none() {
// Parse the WHERE clause
parse_where(
body.where_clause.as_ref().map(|w| *w.clone()),
body.where_clause.as_deref(),
&mut table_references,
Some(&result_columns),
&mut where_clause,
@@ -340,11 +330,7 @@ pub fn prepare_update_plan(
};
// Parse the LIMIT/OFFSET clause
let (limit, offset) = body
.limit
.as_ref()
.map(|l| parse_limit(l))
.unwrap_or(Ok((None, None)))?;
let (limit, offset) = body.limit.as_ref().map_or(Ok((None, None)), parse_limit)?;
// Check what indexes will need to be updated by checking set_clauses and see
// if a column is contained in an index.

View File

@@ -120,7 +120,7 @@ pub fn translate_create_view(
schema: &Schema,
view_name: &str,
select_stmt: &ast::Select,
_columns: Option<&Vec<ast::IndexedColumn>>,
_columns: &[ast::IndexedColumn],
_connection: Arc<Connection>,
syms: &SymbolTable,
mut program: ProgramBuilder,

View File

@@ -175,7 +175,7 @@ pub fn parse_schema_rows(
// Parse the SQL to determine if it's a regular or materialized view
let mut parser = Parser::new(sql.as_bytes());
if let Ok(Some(Cmd::Stmt(stmt))) = parser.next() {
if let Ok(Some(Cmd::Stmt(stmt))) = parser.next_cmd() {
match stmt {
Stmt::CreateMaterializedView { .. } => {
// Handle materialized view with potential reuse
@@ -234,16 +234,15 @@ pub fn parse_schema_rows(
..
} => {
// Extract actual columns from the SELECT statement
let view_columns = extract_view_columns(&select, schema);
let view_columns =
crate::util::extract_view_columns(&select, schema);
// If column names were provided in CREATE VIEW (col1, col2, ...),
// use them to rename the columns
let mut final_columns = view_columns;
if let Some(ref names) = column_names {
for (i, indexed_col) in names.iter().enumerate() {
if let Some(col) = final_columns.get_mut(i) {
col.name = Some(indexed_col.col_name.to_string());
}
for (i, indexed_col) in column_names.iter().enumerate() {
if let Some(col) = final_columns.get_mut(i) {
col.name = Some(indexed_col.col_name.to_string());
}
}
@@ -251,8 +250,8 @@ pub fn parse_schema_rows(
let view = View {
name: name.to_string(),
sql: sql.to_string(),
select_stmt: *select,
columns: Some(final_columns),
select_stmt: select,
columns: final_columns,
};
schema.add_view(view);
}
@@ -509,7 +508,11 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {
}
}
(Expr::Collate(expr1, collation1), Expr::Collate(expr2, collation2)) => {
exprs_are_equivalent(expr1, expr2) && collation1.eq_ignore_ascii_case(collation2)
// TODO: check correctness of comparing colation as strings
exprs_are_equivalent(expr1, expr2)
&& collation1
.as_str()
.eq_ignore_ascii_case(collation2.as_str())
}
(
Expr::FunctionCall {
@@ -544,26 +547,12 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {
},
) => {
name1.as_str().eq_ignore_ascii_case(name2.as_str())
&& match (filter1, filter2) {
&& match (&filter1.filter_clause, &filter2.filter_clause) {
(Some(expr1), Some(expr2)) => exprs_are_equivalent(expr1, expr2),
(None, None) => true,
(
Some(FunctionTail {
filter_clause: fc1,
over_clause: oc1,
}),
Some(FunctionTail {
filter_clause: fc2,
over_clause: oc2,
}),
) => match ((fc1, fc2), (oc1, oc2)) {
((Some(fc1), Some(fc2)), (Some(oc1), Some(oc2))) => {
exprs_are_equivalent(fc1, fc2) && oc1 == oc2
}
((Some(fc1), Some(fc2)), _) => exprs_are_equivalent(fc1, fc2),
_ => false,
},
_ => false,
}
&& filter1.over_clause == filter2.over_clause
}
(Expr::NotNull(expr1), Expr::NotNull(expr2)) => exprs_are_equivalent(expr1, expr2),
(Expr::IsNull(expr1), Expr::IsNull(expr2)) => exprs_are_equivalent(expr1, expr2),
@@ -610,17 +599,11 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {
) => {
*not1 == *not2
&& exprs_are_equivalent(lhs1, lhs2)
&& rhs1.len() == rhs2.len()
&& rhs1
.as_ref()
.zip(rhs2.as_ref())
.map(|(list1, list2)| {
list1.len() == list2.len()
&& list1
.iter()
.zip(list2)
.all(|(e1, e2)| exprs_are_equivalent(e1, e2))
})
.unwrap_or(false)
.iter()
.zip(rhs2.iter())
.all(|(a, b)| exprs_are_equivalent(a, b))
}
// fall back to naive equality check
_ => expr1 == expr2,
@@ -639,63 +622,58 @@ pub fn columns_from_create_table_body(
use turso_parser::ast;
Ok(columns
.into_iter()
.map(|(name, column_def)| {
Column {
name: Some(normalize_ident(name.as_str())),
ty: match column_def.col_type {
Some(ref data_type) => {
// https://www.sqlite.org/datatype3.html
let type_name = data_type.name.as_str().to_uppercase();
if type_name.contains("INT") {
Type::Integer
} else if type_name.contains("CHAR")
|| type_name.contains("CLOB")
|| type_name.contains("TEXT")
{
Type::Text
} else if type_name.contains("BLOB") || type_name.is_empty() {
Type::Blob
} else if type_name.contains("REAL")
|| type_name.contains("FLOA")
|| type_name.contains("DOUB")
{
Type::Real
} else {
Type::Numeric
.iter()
.map(
|ast::ColumnDefinition {
col_name: name,
col_type,
constraints,
}| {
Column {
name: Some(normalize_ident(name.as_str())),
ty: match col_type {
Some(ref data_type) => {
// https://www.sqlite.org/datatype3.html
let type_name = data_type.name.as_str().to_uppercase();
if type_name.contains("INT") {
Type::Integer
} else if type_name.contains("CHAR")
|| type_name.contains("CLOB")
|| type_name.contains("TEXT")
{
Type::Text
} else if type_name.contains("BLOB") || type_name.is_empty() {
Type::Blob
} else if type_name.contains("REAL")
|| type_name.contains("FLOA")
|| type_name.contains("DOUB")
{
Type::Real
} else {
Type::Numeric
}
}
}
None => Type::Null,
},
default: column_def
.constraints
.iter()
.find_map(|c| match &c.constraint {
None => Type::Null,
},
default: constraints.iter().find_map(|c| match &c.constraint {
ast::ColumnConstraint::Default(val) => Some(val.clone()),
_ => None,
}),
notnull: column_def
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::NotNull { .. })),
ty_str: column_def
.col_type
.clone()
.map(|t| t.name.to_string())
.unwrap_or_default(),
primary_key: column_def
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::PrimaryKey { .. })),
is_rowid_alias: false,
unique: column_def
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::Unique(..))),
collation: column_def
.constraints
.iter()
.find_map(|c| match &c.constraint {
notnull: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::NotNull { .. })),
ty_str: col_type
.clone()
.map(|t| t.name.to_string())
.unwrap_or_default(),
primary_key: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::PrimaryKey { .. })),
is_rowid_alias: false,
unique: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::Unique(..))),
collation: constraints.iter().find_map(|c| match &c.constraint {
// TODO: see if this should be the correct behavior
// currently there cannot be any user defined collation sequences.
// But in the future, when a user defines a collation sequence, creates a table with it,
@@ -707,13 +685,13 @@ pub fn columns_from_create_table_body(
),
_ => None,
}),
hidden: column_def
.col_type
.as_ref()
.map(|data_type| data_type.name.as_str().contains("HIDDEN"))
.unwrap_or(false),
}
})
hidden: col_type
.as_ref()
.map(|data_type| data_type.name.as_str().contains("HIDDEN"))
.unwrap_or(false),
}
},
)
.collect::<Vec<_>>())
}
@@ -735,10 +713,7 @@ pub fn can_pushdown_predicate(
can_pushdown &= join_idx <= table_idx;
}
Expr::FunctionCall { args, name, .. } => {
let function = crate::function::Func::resolve_function(
name.as_str(),
args.as_ref().map_or(0, |a| a.len()),
)?;
let function = crate::function::Func::resolve_function(name.as_str(), args.len())?;
// is deterministic
can_pushdown &= function.is_deterministic();
}
@@ -1216,8 +1191,8 @@ pub fn parse_pragma_bool(expr: &Expr) -> Result<bool> {
}
/// Extract column name from an expression (e.g., for SELECT clauses)
pub fn extract_column_name_from_expr(expr: &ast::Expr) -> Option<String> {
match expr {
pub fn extract_column_name_from_expr(expr: impl AsRef<ast::Expr>) -> Option<String> {
match expr.as_ref() {
ast::Expr::Id(name) => Some(name.as_str().to_string()),
ast::Expr::Qualified(_, name) => Some(name.as_str().to_string()),
_ => None,
@@ -1228,10 +1203,15 @@ pub fn extract_column_name_from_expr(expr: &ast::Expr) -> Option<String> {
pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec<Column> {
let mut columns = Vec::new();
// Navigate to the first SELECT in the statement
if let ast::OneSelect::Select(select_core) = select_stmt.body.select.as_ref() {
if let ast::OneSelect::Select {
ref from,
columns: select_columns,
..
} = &select_stmt.body.select
{
// First, we need to figure out which table(s) are being selected from
let table_name = if let Some(from) = &select_core.from {
if let Some(ast::SelectTable::Table(qualified_name, _, _)) = from.select.as_deref() {
let table_name = if let Some(from) = from {
if let ast::SelectTable::Table(qualified_name, _, _) = from.select.as_ref() {
Some(normalize_ident(qualified_name.name.as_str()))
} else {
None
@@ -1242,7 +1222,7 @@ pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec<C
// Get the table for column resolution
let _table = table_name.as_ref().and_then(|name| schema.get_table(name));
// Process each column in the SELECT list
for (i, result_col) in select_core.columns.iter().enumerate() {
for (i, result_col) in select_columns.iter().enumerate() {
match result_col {
ast::ResultColumn::Expr(expr, alias) => {
let name = alias
@@ -1456,25 +1436,34 @@ pub mod tests {
let func1 = Expr::FunctionCall {
name: Name::Ident("SUM".to_string()),
distinctness: None,
args: Some(vec![Expr::Id(Name::Ident("x".to_string()))]),
order_by: None,
filter_over: None,
args: vec![Expr::Id(Name::Ident("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
let func2 = Expr::FunctionCall {
name: Name::Ident("sum".to_string()),
distinctness: None,
args: Some(vec![Expr::Id(Name::Ident("x".to_string()))]),
order_by: None,
filter_over: None,
args: vec![Expr::Id(Name::Ident("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
assert!(exprs_are_equivalent(&func1, &func2));
let func3 = Expr::FunctionCall {
name: Name::Ident("SUM".to_string()),
distinctness: Some(ast::Distinctness::Distinct),
args: Some(vec![Expr::Id(Name::Ident("x".to_string()))]),
order_by: None,
filter_over: None,
args: vec![Expr::Id(Name::Ident("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
assert!(!exprs_are_equivalent(&func1, &func3));
}
@@ -1484,16 +1473,22 @@ pub mod tests {
let sum = Expr::FunctionCall {
name: Name::Ident("SUM".to_string()),
distinctness: None,
args: Some(vec![Expr::Id(Name::Ident("x".to_string()))]),
order_by: None,
filter_over: None,
args: vec![Expr::Id(Name::Ident("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
let sum_distinct = Expr::FunctionCall {
name: Name::Ident("SUM".to_string()),
distinctness: Some(ast::Distinctness::Distinct),
args: Some(vec![Expr::Id(Name::Ident("x".to_string()))]),
order_by: None,
filter_over: None,
args: vec![Expr::Id(Name::Ident("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
assert!(!exprs_are_equivalent(&sum, &sum_distinct));
}
@@ -1519,7 +1514,8 @@ pub mod tests {
Box::new(Expr::Literal(Literal::Numeric("683".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("799.0".to_string()))),
)]);
)
.into()]);
let expr2 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("799".to_string()))),
Add,
@@ -1533,7 +1529,8 @@ pub mod tests {
Box::new(Expr::Literal(Literal::Numeric("6".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("7".to_string()))),
)]);
)
.into()]);
let expr8 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("6".to_string()))),
Add,

View File

@@ -1,7 +1,7 @@
use std::{cell::Cell, cmp::Ordering, sync::Arc};
use tracing::{instrument, Level};
use turso_sqlite3_parser::ast::{self, TableInternalId};
use turso_parser::ast::{self, TableInternalId};
use crate::{
numeric::Numeric,
@@ -17,7 +17,7 @@ use crate::{
#[derive(Default)]
pub struct TableRefIdCounter {
next_free: TableInternalId,
next_free: ast::TableInternalId,
}
impl TableRefIdCounter {
@@ -868,7 +868,7 @@ impl ProgramBuilder {
_ => break 'value None,
};
let Some(ast::Expr::Literal(ref literal)) = default else {
let Some(ast::Expr::Literal(ref literal)) = default.as_ref().map(|v| v.as_ref()) else {
break 'value None;
};

View File

@@ -68,7 +68,6 @@ use super::{
insn::{Cookie, RegisterOrLiteral},
CommitState,
};
use fallible_iterator::FallibleIterator;
use parking_lot::RwLock;
use rand::{thread_rng, Rng};
use turso_parser::ast;
@@ -4866,7 +4865,7 @@ pub fn op_function(
unique,
if_not_exists,
idx_name,
tbl_name: ast::Name::from_str(&rename_to),
tbl_name: ast::Name::new(&rename_to),
columns,
where_clause,
}
@@ -4892,7 +4891,7 @@ pub fn op_function(
if_not_exists,
tbl_name: ast::QualifiedName {
db_name: None,
name: ast::Name::from_str(&rename_to),
name: ast::Name::new(&rename_to),
alias: None,
},
body,
@@ -4957,7 +4956,7 @@ pub fn op_function(
}
for column in &mut columns {
match &mut column.expr {
match column.expr.as_mut() {
ast::Expr::Id(ast::Name::Ident(id))
if normalize_ident(id) == rename_from =>
{
@@ -4994,43 +4993,28 @@ pub fn op_function(
mut columns,
constraints,
options,
} = *body
} = body
else {
todo!()
};
let column_index = columns
.get_index_of(&ast::Name::from_str(&rename_from))
let column = columns
.iter_mut()
.find(|column| column.col_name == ast::Name::new(&rename_from))
.expect("column being renamed should be present");
let mut column_definition =
columns.get_index(column_index).unwrap().1.clone();
column_definition.col_name = ast::Name::from_str(&rename_to);
assert!(columns
.insert(
ast::Name::from_str(&rename_to),
column_definition.clone()
)
.is_none());
// Swaps indexes with the last one and pops the end, effectively
// replacing the entry.
columns.swap_remove_index(column_index).unwrap();
column.col_name = ast::Name::new(&rename_to);
Some(
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name,
body: Box::new(
ast::CreateTableBody::ColumnsAndConstraints {
columns,
constraints,
options,
},
),
body: ast::CreateTableBody::ColumnsAndConstraints {
columns,
constraints,
options,
},
}
.format()
.unwrap(),

View File

@@ -1,4 +1,4 @@
use turso_sqlite3_parser::ast::SortOrder;
use turso_parser::ast::SortOrder;
use crate::vdbe::{builder::CursorType, insn::RegisterOrLiteral};

View File

@@ -11,7 +11,7 @@ use crate::{
Value,
};
use turso_macros::Description;
use turso_sqlite3_parser::ast::SortOrder;
use turso_parser::ast::SortOrder;
/// Flags provided to comparison instructions (e.g. Eq, Ne) which determine behavior related to NULL values.
#[derive(Clone, Copy, Debug, Default)]

View File

@@ -1,4 +1,4 @@
use turso_sqlite3_parser::ast::SortOrder;
use turso_parser::ast::SortOrder;
use std::cell::{Cell, RefCell};
use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd, Reverse};

View File

@@ -2,7 +2,7 @@ use crate::pragma::{PragmaVirtualTable, PragmaVirtualTableCursor};
use crate::schema::Column;
use crate::util::columns_from_create_table_body;
use crate::{Connection, LimboError, SymbolTable, Value};
use fallible_iterator::FallibleIterator;
use std::ffi::c_void;
use std::ptr::NonNull;
use std::rc::Rc;
@@ -105,7 +105,7 @@ impl VirtualTable {
fn resolve_columns(schema: String) -> crate::Result<Vec<Column>> {
let mut parser = Parser::new(schema.as_bytes());
if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next()?.ok_or(
if let ast::Cmd::Stmt(ast::Stmt::CreateTable { body, .. }) = parser.next_cmd()?.ok_or(
LimboError::ParseError("Failed to parse schema from virtual table module".to_string()),
)? {
columns_from_create_table_body(&body)