diff --git a/core/schema.rs b/core/schema.rs index 49cacd67b..7e7fc953e 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -1,4 +1,8 @@ +use crate::function::Func; use crate::incremental::view::IncrementalView; +use crate::translate::emitter::Resolver; +use crate::translate::expr::{bind_and_rewrite_expr, walk_expr, ParamState, WalkControl}; +use crate::translate::optimizer::Optimizable; use parking_lot::RwLock; /// Simple view structure for non-materialized views @@ -15,13 +19,13 @@ pub type ViewsMap = HashMap; use crate::storage::btree::BTreeCursor; use crate::translate::collate::CollationSeq; -use crate::translate::plan::SelectPlan; +use crate::translate::plan::{SelectPlan, TableReferences}; use crate::util::{ module_args_from_sql, module_name_from_sql, type_from_name, IOExt, UnparsedFromSqlIndex, }; use crate::{ - contains_ignore_ascii_case, eq_ignore_ascii_case, match_ignore_ascii_case, LimboError, - MvCursor, MvStore, Pager, RefValue, SymbolTable, VirtualTable, + contains_ignore_ascii_case, eq_ignore_ascii_case, match_ignore_ascii_case, Connection, + LimboError, MvCursor, MvStore, Pager, RefValue, SymbolTable, VirtualTable, }; use crate::{util::normalize_ident, Result}; use core::fmt; @@ -32,7 +36,7 @@ use std::sync::Mutex; use tracing::trace; use turso_parser::ast::{self, ColumnDefinition, Expr, Literal, SortOrder, TableOptions}; use turso_parser::{ - ast::{Cmd, CreateTableBody, ResultColumn, Stmt}, + ast::{Cmd, CreateTableBody, Name, ResultColumn, Stmt}, parser::Parser, }; @@ -1750,6 +1754,102 @@ impl Index { .iter() .position(|c| c.pos_in_table == table_pos) } + + /// Walk the where_clause Expr of a partial index and validate that it doesn't reference any other + /// tables or use any disallowed constructs. + pub fn validate_where_expr(&self, table: &Table) -> bool { + let Some(where_clause) = &self.where_clause else { + return true; + }; + + let tbl_norm = normalize_ident(self.table_name.as_str()); + let has_col = |name: &str| { + let n = normalize_ident(name); + table + .columns() + .iter() + .any(|c| c.name.as_ref().is_some_and(|cn| normalize_ident(cn) == n)) + }; + let is_tbl = |ns: &str| normalize_ident(ns).eq_ignore_ascii_case(&tbl_norm); + let is_deterministic_fn = |name: &str, argc: usize| { + let n = normalize_ident(name); + Func::resolve_function(&n, argc).is_ok_and(|f| f.is_deterministic()) + }; + + let mut ok = true; + let _ = walk_expr(where_clause.as_ref(), &mut |e: &Expr| -> crate::Result< + WalkControl, + > { + if !ok { + return Ok(WalkControl::SkipChildren); + } + match e { + Expr::Literal(_) | Expr::RowId { .. } => {} + // Unqualified identifier: must be a column of the target table or ROWID + Expr::Id(Name::Ident(n)) | Expr::Id(Name::Quoted(n)) => { + let n = n.as_str(); + if !n.eq_ignore_ascii_case("rowid") && !has_col(n) { + ok = false; + } + } + // Qualified: qualifier must match this index's table; column must exist + Expr::Qualified(ns, col) | Expr::DoublyQualified(_, ns, col) => { + if !is_tbl(ns.as_str()) || !has_col(col.as_str()) { + ok = false; + } + } + Expr::FunctionCall { + name, filter_over, .. + } + | Expr::FunctionCallStar { + name, filter_over, .. + } => { + // reject windowed + if filter_over.over_clause.is_some() { + ok = false; + } else { + let argc = match e { + Expr::FunctionCall { args, .. } => args.len(), + Expr::FunctionCallStar { .. } => 0, + _ => unreachable!(), + }; + if !is_deterministic_fn(name.as_str(), argc) { + ok = false; + } + } + } + // Explicitly disallowed constructs + Expr::Exists(_) + | Expr::InSelect { .. } + | Expr::Subquery(_) + | Expr::Raise { .. } + | Expr::Variable(_) => { + ok = false; + } + _ => {} + } + Ok(if ok { + WalkControl::Continue + } else { + WalkControl::SkipChildren + }) + }); + ok + } + + pub fn bind_where_expr( + &self, + table_refs: Option<&mut TableReferences>, + connection: &Arc, + ) -> Option { + let Some(where_clause) = &self.where_clause else { + return None; + }; + let mut params = ParamState::disallow(); + let mut expr = where_clause.clone(); + bind_and_rewrite_expr(&mut expr, table_refs, None, connection, &mut params).ok()?; + Some(*expr) + } } #[cfg(test)]