diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 7e5cb375a..760291613 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -97,7 +97,7 @@ pub(crate) unsafe extern "C" fn register_scalar_function( unsafe { (*ext_ctx.syms).functions.insert( name_str.clone(), - Rc::new(ExternalFunc::new_scalar(name_str, func)), + Arc::new(ExternalFunc::new_scalar(name_str, func)), ); } ResultCode::OK @@ -123,7 +123,7 @@ pub(crate) unsafe extern "C" fn register_aggregate_function( unsafe { (*ext_ctx.syms).functions.insert( name_str.clone(), - Rc::new(ExternalFunc::new_aggregate( + Arc::new(ExternalFunc::new_aggregate( name_str, args, (init_func, step_func, finalize_func), diff --git a/core/function.rs b/core/function.rs index 3f2d2a71b..c5b3f595a 100644 --- a/core/function.rs +++ b/core/function.rs @@ -1,6 +1,7 @@ use std::fmt; use std::fmt::{Debug, Display}; use std::rc::Rc; +use std::sync::Arc; use turso_ext::{FinalizeFunction, InitAggFunction, ScalarFunction, StepFunction}; use crate::LimboError; @@ -593,7 +594,7 @@ pub enum Func { #[cfg(feature = "json")] Json(JsonFunc), AlterTable(AlterTableFunc), - External(Rc), + External(Arc), } impl Display for Func { diff --git a/core/lib.rs b/core/lib.rs index 4b7b40361..d1e17b2f3 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1676,8 +1676,8 @@ pub type StepResult = vdbe::StepResult; #[derive(Default)] pub struct SymbolTable { - pub functions: HashMap>, - pub vtabs: HashMap>, + pub functions: HashMap>, + pub vtabs: HashMap>, pub vtab_modules: HashMap>, } @@ -1726,7 +1726,7 @@ impl SymbolTable { &self, name: &str, _arg_count: usize, - ) -> Option> { + ) -> Option> { self.functions.get(name).cloned() } diff --git a/core/schema.rs b/core/schema.rs index 5fbabad35..68a605c60 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -10,6 +10,7 @@ use fallible_iterator::FallibleIterator; use std::cell::RefCell; use std::collections::hash_map::Entry; use std::collections::{BTreeSet, HashMap}; +use std::ops::Deref; use std::rc::Rc; use std::sync::Arc; use tracing::trace; @@ -22,7 +23,7 @@ use turso_sqlite3_parser::{ const SCHEMA_TABLE_NAME: &str = "sqlite_schema"; const SCHEMA_TABLE_NAME_ALT: &str = "sqlite_master"; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Schema { pub tables: HashMap>, /// table_name to list of indexes for the table @@ -43,7 +44,10 @@ impl Schema { Arc::new(Table::BTree(sqlite_schema_table().into())), ); for function in VirtualTable::builtin_functions() { - tables.insert(function.name.to_owned(), Arc::new(Table::Virtual(function))); + tables.insert( + function.name.to_owned(), + Arc::new(Table::Virtual(Arc::new((*function).clone()))), + ); } Self { tables, @@ -61,12 +65,12 @@ impl Schema { .any(|idx| idx.1.iter().any(|i| i.name == name)) } - pub fn add_btree_table(&mut self, table: Rc) { + pub fn add_btree_table(&mut self, table: Arc) { let name = normalize_ident(&table.name); self.tables.insert(name, Table::BTree(table).into()); } - pub fn add_virtual_table(&mut self, table: Rc) { + pub fn add_virtual_table(&mut self, table: Arc) { let name = normalize_ident(&table.name); self.tables.insert(name, Table::Virtual(table).into()); } @@ -86,7 +90,7 @@ impl Schema { self.tables.remove(&name); } - pub fn get_btree_table(&self, name: &str) -> Option> { + pub fn get_btree_table(&self, name: &str) -> Option> { let name = normalize_ident(name); if let Some(table) = self.tables.get(&name) { table.btree() @@ -197,22 +201,23 @@ impl Schema { // longer in the in-memory schema. We need to recreate it if // the module is loaded in the symbol table. let vtab = if let Some(vtab) = syms.vtabs.get(name) { - vtab.clone() + Arc::new((**vtab).clone()) } else { let mod_name = module_name_from_sql(sql)?; - crate::VirtualTable::table( + let vtab_rc = crate::VirtualTable::table( Some(name), mod_name, module_args_from_sql(sql)?, syms, - )? + )?; + Arc::new((*vtab_rc).clone()) }; self.add_virtual_table(vtab); continue; } let table = BTreeTable::from_sql(sql, root_page as usize)?; - self.add_btree_table(Rc::new(table)); + self.add_btree_table(Arc::new(table)); } "index" => { let root_page_value = record_cursor.get_value(&row, 3)?; @@ -311,10 +316,61 @@ impl Schema { } } +impl Clone for Schema { + /// Cloning a `Schema` requires deep cloning of all internal tables and indexes, even though they are wrapped in `Arc`. + /// Simply copying the `Arc` pointers would result in multiple `Schema` instances sharing the same underlying tables and indexes, + /// which could lead to panics or data races if any instance attempts to modify them. + /// To ensure each `Schema` is independent and safe to modify, we clone the underlying data for all tables and indexes. + fn clone(&self) -> Self { + let tables = self + .tables + .iter() + .map(|(name, table)| match table.deref() { + Table::BTree(table) => { + let table = Arc::deref(table); + ( + name.clone(), + Arc::new(Table::BTree(Arc::new(table.clone()))), + ) + } + Table::Virtual(table) => { + let table = Arc::deref(table); + ( + name.clone(), + Arc::new(Table::Virtual(Arc::new(table.clone()))), + ) + } + Table::FromClauseSubquery(from_clause_subquery) => ( + name.clone(), + Arc::new(Table::FromClauseSubquery(from_clause_subquery.clone())), + ), + }) + .collect(); + let indexes = self + .indexes + .iter() + .map(|(name, indexes)| { + let indexes = indexes + .iter() + .map(|index| Arc::new((**index).clone())) + .collect(); + (name.clone(), indexes) + }) + .collect(); + Self { + tables, + indexes, + has_indexes: self.has_indexes.clone(), + indexes_enabled: self.indexes_enabled, + schema_version: self.schema_version, + } + } +} + #[derive(Clone, Debug)] pub enum Table { - BTree(Rc), - Virtual(Rc), + BTree(Arc), + Virtual(Arc), FromClauseSubquery(FromClauseSubquery), } @@ -353,7 +409,7 @@ impl Table { } } - pub fn btree(&self) -> Option> { + pub fn btree(&self) -> Option> { match self { Self::BTree(table) => Some(table.clone()), Self::Virtual(_) => None, @@ -361,7 +417,7 @@ impl Table { } } - pub fn virtual_table(&self) -> Option> { + pub fn virtual_table(&self) -> Option> { match self { Self::Virtual(table) => Some(table.clone()), _ => None, @@ -372,8 +428,8 @@ impl Table { impl PartialEq for Table { fn eq(&self, other: &Self) -> bool { match (self, other) { - (Self::BTree(a), Self::BTree(b)) => Rc::ptr_eq(a, b), - (Self::Virtual(a), Self::Virtual(b)) => Rc::ptr_eq(a, b), + (Self::BTree(a), Self::BTree(b)) => Arc::ptr_eq(a, b), + (Self::Virtual(a), Self::Virtual(b)) => Arc::ptr_eq(a, b), _ => false, } } @@ -1135,7 +1191,7 @@ pub fn sqlite_schema_table() -> BTreeTable { } #[allow(dead_code)] -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Index { pub name: String, pub table_name: String, diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 552317067..288ab3a80 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -1,7 +1,7 @@ // This module contains code for emitting bytecode instructions for SQL query execution. // It handles translating high-level SQL operations into low-level bytecode that can be executed by the virtual machine. -use std::rc::Rc; +use std::sync::Arc; use tracing::{instrument, Level}; use turso_sqlite3_parser::ast::{self, Expr}; @@ -1055,7 +1055,7 @@ fn emit_update_insns( start_reg: start, count: table_ref.columns().len(), check_generated: true, - table_reference: Rc::clone(&btree_table), + table_reference: Arc::clone(&btree_table), }); } diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 4b1cd9aef..acbf10649 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -1,4 +1,3 @@ -use std::rc::Rc; use std::sync::Arc; use turso_sqlite3_parser::ast::{ @@ -439,7 +438,7 @@ pub fn translate_insert( start_reg: columns_start_register, count: num_cols, check_generated: true, - table_reference: Rc::clone(&t), + table_reference: Arc::clone(&t), }); } _ => (), @@ -962,7 +961,7 @@ fn populate_column_registers( // TODO: comeback here later to apply the same improvements on select fn translate_virtual_table_insert( mut program: ProgramBuilder, - virtual_table: Rc, + virtual_table: Arc, columns: Option, mut body: InsertBody, on_conflict: Option, diff --git a/core/translate/optimizer/join.rs b/core/translate/optimizer/join.rs index b46722d9d..078bee124 100644 --- a/core/translate/optimizer/join.rs +++ b/core/translate/optimizer/join.rs @@ -490,7 +490,7 @@ fn generate_join_bitmasks(table_number_max_exclusive: usize, how_many: usize) -> #[cfg(test)] mod tests { - use std::{cell::Cell, rc::Rc, sync::Arc}; + use std::{cell::Cell, sync::Arc}; use turso_sqlite3_parser::ast::{self, Expr, Operator, SortOrder, TableInternalId}; @@ -1640,8 +1640,8 @@ mod tests { } /// Creates a BTreeTable with the given name and columns - fn _create_btree_table(name: &str, columns: Vec) -> Rc { - Rc::new(BTreeTable { + fn _create_btree_table(name: &str, columns: Vec) -> Arc { + Arc::new(BTreeTable { root_page: 1, // Page number doesn't matter for tests name: name.to_string(), primary_key_columns: vec![], @@ -1654,7 +1654,7 @@ mod tests { /// Creates a TableReference for a BTreeTable fn _create_table_reference( - table: Rc, + table: Arc, join_info: Option, internal_id: TableInternalId, ) -> JoinedTable { diff --git a/core/translate/plan.rs b/core/translate/plan.rs index acf22aa79..663eb60e6 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, cmp::Ordering, rc::Rc, sync::Arc}; +use std::{cell::Cell, cmp::Ordering, sync::Arc}; use turso_ext::{ConstraintInfo, ConstraintOp}; use turso_sqlite3_parser::ast::{self, SortOrder}; @@ -333,7 +333,7 @@ pub enum QueryDestination { /// The cursor ID of the ephemeral table that will be used to store the results. cursor_id: CursorID, /// The table that will be used to store the results. - table: Rc, + table: Arc, }, } @@ -887,13 +887,13 @@ impl Operation { impl JoinedTable { /// Returns the btree table for this table reference, if it is a BTreeTable. - pub fn btree(&self) -> Option> { + pub fn btree(&self) -> Option> { match &self.table { Table::BTree(_) => self.table.btree(), _ => None, } } - pub fn virtual_table(&self) -> Option> { + pub fn virtual_table(&self) -> Option> { match &self.table { Table::Virtual(_) => self.table.virtual_table(), _ => None, diff --git a/core/translate/schema.rs b/core/translate/schema.rs index a0b05e09f..9ee88125d 100644 --- a/core/translate/schema.rs +++ b/core/translate/schema.rs @@ -1,6 +1,7 @@ use std::collections::HashSet; use std::ops::Range; use std::rc::Rc; +use std::sync::Arc; use crate::ast; use crate::ext::VTabImpl; @@ -772,7 +773,7 @@ pub fn translate_drop_table( // cursor id 1 let sqlite_schema_cursor_id_1 = program.alloc_cursor_id(CursorType::BTreeTable(schema_table.clone())); - let simple_table_rc = Rc::new(BTreeTable { + let simple_table_rc = Arc::new(BTreeTable { root_page: 0, // Not relevant for ephemeral table definition name: "ephemeral_scratch".to_string(), has_rowid: true, diff --git a/core/translate/update.rs b/core/translate/update.rs index 3cd8f279c..72a0d29b3 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::rc::Rc; use std::sync::Arc; use crate::schema::{BTreeTable, Column, Type}; @@ -235,7 +234,7 @@ pub fn prepare_update_plan( connection, )?; - let table = Rc::new(BTreeTable { + let table = Arc::new(BTreeTable { root_page: 0, // Not relevant for ephemeral table definition name: "ephemeral_scratch".to_string(), has_rowid: true, diff --git a/core/util.rs b/core/util.rs index 7697784ef..961288afe 100644 --- a/core/util.rs +++ b/core/util.rs @@ -111,7 +111,7 @@ pub fn parse_schema_rows( schema.add_virtual_table(vtab); } else { let table = schema::BTreeTable::from_sql(sql, root_page as usize)?; - schema.add_btree_table(Rc::new(table)); + schema.add_btree_table(Arc::new(table)); } } "index" => { diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 2d8b77e26..d684472d8 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, cmp::Ordering, rc::Rc, sync::Arc}; +use std::{cell::Cell, cmp::Ordering, sync::Arc}; use tracing::{instrument, Level}; use turso_sqlite3_parser::ast::{self, TableInternalId}; @@ -115,11 +115,11 @@ pub struct ProgramBuilder { #[derive(Debug, Clone)] pub enum CursorType { - BTreeTable(Rc), + BTreeTable(Arc), BTreeIndex(Arc), Pseudo(PseudoCursorType), Sorter, - VirtualTable(Rc), + VirtualTable(Arc), } impl CursorType { diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 5ccfaa0c0..92dc66a15 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -1,6 +1,5 @@ use std::{ num::{NonZero, NonZeroUsize}, - rc::Rc, sync::Arc, }; @@ -417,7 +416,7 @@ pub enum Insn { /// GENERATED ALWAYS AS ... STATIC columns are only checked if P3 is zero. /// When P3 is non-zero, no type checking occurs for static generated columns. check_generated: bool, // P3 - table_reference: Rc, // P4 + table_reference: Arc, // P4 }, // Make a record and write it to destination register. diff --git a/core/vtab.rs b/core/vtab.rs index 008f1183a..6d92ac7f1 100644 --- a/core/vtab.rs +++ b/core/vtab.rs @@ -25,14 +25,14 @@ pub struct VirtualTable { } impl VirtualTable { - pub(crate) fn readonly(self: &Rc) -> bool { + pub(crate) fn readonly(self: &Arc) -> bool { match &self.vtab_type { VirtualTableType::Pragma(_) => true, VirtualTableType::External(table) => table.readonly(), } } - pub(crate) fn builtin_functions() -> Vec> { + pub(crate) fn builtin_functions() -> Vec> { PragmaVirtualTable::functions() .into_iter() .map(|(tab, schema)| { @@ -43,12 +43,12 @@ impl VirtualTable { kind: VTabKind::TableValuedFunction, vtab_type: VirtualTableType::Pragma(tab), }; - Rc::new(vtab) + Arc::new(vtab) }) .collect() } - pub(crate) fn function(name: &str, syms: &SymbolTable) -> crate::Result> { + pub(crate) fn function(name: &str, syms: &SymbolTable) -> crate::Result> { let module = syms.vtab_modules.get(name); let (vtab_type, schema) = if module.is_some() { ExtVirtualTable::create(name, module, Vec::new(), VTabKind::TableValuedFunction) @@ -65,7 +65,7 @@ impl VirtualTable { kind: VTabKind::TableValuedFunction, vtab_type, }; - Ok(Rc::new(vtab)) + Ok(Arc::new(vtab)) } pub fn table( @@ -73,7 +73,7 @@ impl VirtualTable { module_name: &str, args: Vec, syms: &SymbolTable, - ) -> crate::Result> { + ) -> crate::Result> { let module = syms.vtab_modules.get(module_name); let (table, schema) = ExtVirtualTable::create(module_name, module, args, VTabKind::VirtualTable)?; @@ -83,7 +83,7 @@ impl VirtualTable { kind: VTabKind::VirtualTable, vtab_type: VirtualTableType::External(table), }; - Ok(Rc::new(vtab)) + Ok(Arc::new(vtab)) } fn resolve_columns(schema: String) -> crate::Result> {