diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index 297b5a7ce..628e067af 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -5,7 +5,6 @@ mod types; use limbo_core::{Connection, Database, LimboError, IO}; use std::{ ffi::{c_char, c_void}, - rc::Rc, sync::Arc, }; @@ -40,13 +39,13 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { #[allow(dead_code)] struct LimboConn { - conn: Rc, + conn: Arc, io: Arc, err: Option, } impl LimboConn { - fn new(conn: Rc, io: Arc) -> Self { + fn new(conn: Arc, io: Arc) -> Self { LimboConn { conn, io, diff --git a/bindings/java/rs_src/limbo_connection.rs b/bindings/java/rs_src/limbo_connection.rs index 68e7ff2de..02b2a0eb8 100644 --- a/bindings/java/rs_src/limbo_connection.rs +++ b/bindings/java/rs_src/limbo_connection.rs @@ -8,19 +8,16 @@ use jni::objects::{JByteArray, JObject}; use jni::sys::jlong; use jni::JNIEnv; use limbo_core::Connection; -use std::rc::Rc; use std::sync::Arc; #[derive(Clone)] pub struct LimboConnection { - // Because java's LimboConnection is 1:1 mapped to limbo connection, we can use Rc - pub(crate) conn: Rc, - // Because io is shared across multiple `LimboConnection`s, wrap it with Arc + pub(crate) conn: Arc, pub(crate) io: Arc, } impl LimboConnection { - pub fn new(conn: Rc, io: Arc) -> Self { + pub fn new(conn: Arc, io: Arc) -> Self { LimboConnection { conn, io } } diff --git a/bindings/javascript/src/lib.rs b/bindings/javascript/src/lib.rs index 6d6b15939..ffa9f966e 100644 --- a/bindings/javascript/src/lib.rs +++ b/bindings/javascript/src/lib.rs @@ -41,7 +41,7 @@ pub struct Database { #[napi(writable = false)] pub name: String, _db: Arc, - conn: Rc, + conn: Arc, io: Arc, } diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 4958074a2..5cf3a9f60 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -228,7 +228,7 @@ fn stmt_is_ddl(sql: &str) -> bool { #[pyclass(unsendable)] #[derive(Clone)] pub struct Connection { - conn: Rc, + conn: Arc, io: Arc, } @@ -310,13 +310,13 @@ pub fn connect(path: &str) -> Result { ":memory:" => { let io: Arc = Arc::new(limbo_core::MemoryIO::new()); let db = open_or(io.clone(), path)?; - let conn: Rc = db.connect().unwrap(); + let conn: Arc = db.connect().unwrap(); Ok(Connection { conn, io }) } path => { let io: Arc = Arc::new(limbo_core::PlatformIO::new()?); let db = open_or(io.clone(), path)?; - let conn: Rc = db.connect().unwrap(); + let conn: Arc = db.connect().unwrap(); Ok(Connection { conn, io }) } } diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index ccd5dae35..9f29fb420 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -8,7 +8,6 @@ pub use params::params_from_iter; use crate::params::*; use std::fmt::Debug; use std::num::NonZero; -use std::rc::Rc; use std::sync::{Arc, Mutex}; #[derive(Debug, thiserror::Error)] @@ -84,7 +83,7 @@ impl Database { } pub struct Connection { - inner: Arc>>, + inner: Arc>>, } impl Clone for Connection { diff --git a/bindings/wasm/lib.rs b/bindings/wasm/lib.rs index 3bf8e14ff..bbb56b9c7 100644 --- a/bindings/wasm/lib.rs +++ b/bindings/wasm/lib.rs @@ -1,14 +1,13 @@ use js_sys::{Array, Object}; use limbo_core::{maybe_init_database_file, Clock, Instant, OpenFlags, Result}; use std::cell::RefCell; -use std::rc::Rc; use std::sync::Arc; use wasm_bindgen::prelude::*; #[allow(dead_code)] #[wasm_bindgen] pub struct Database { db: Arc, - conn: Rc, + conn: Arc, } #[allow(clippy::arc_with_non_send_sync)] diff --git a/cli/app.rs b/cli/app.rs index a403aa591..6ab03408f 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -21,7 +21,6 @@ use std::{ fmt, io::{self, BufRead as _, Write}, path::PathBuf, - rc::Rc, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -68,7 +67,7 @@ pub struct Limbo { pub prompt: String, io: Arc, writer: Box, - conn: Rc, + conn: Arc, pub interrupt_count: Arc, input_buff: String, opts: Settings, diff --git a/cli/commands/import.rs b/cli/commands/import.rs index 38ec5df45..7450178d0 100644 --- a/cli/commands/import.rs +++ b/cli/commands/import.rs @@ -1,7 +1,7 @@ use clap::Args; use clap_complete::{ArgValueCompleter, PathCompleter}; use limbo_core::Connection; -use std::{fs::File, io::Write, path::PathBuf, rc::Rc, sync::Arc}; +use std::{fs::File, io::Write, path::PathBuf, sync::Arc}; #[derive(Debug, Clone, Args)] pub struct ImportArgs { @@ -20,14 +20,14 @@ pub struct ImportArgs { } pub struct ImportFile<'a> { - conn: Rc, + conn: Arc, io: Arc, writer: &'a mut dyn Write, } impl<'a> ImportFile<'a> { pub fn new( - conn: Rc, + conn: Arc, io: Arc, writer: &'a mut dyn Write, ) -> Self { diff --git a/cli/helper.rs b/cli/helper.rs index f8e606a89..141e8f8d5 100644 --- a/cli/helper.rs +++ b/cli/helper.rs @@ -8,7 +8,6 @@ use rustyline::{Completer, Helper, Hinter, Validator}; use shlex::Shlex; use std::cell::RefCell; use std::marker::PhantomData; -use std::rc::Rc; use std::sync::Arc; use std::{ffi::OsString, path::PathBuf, str::FromStr as _}; use syntect::dumps::from_uncompressed_data; @@ -42,7 +41,7 @@ pub struct LimboHelper { impl LimboHelper { pub fn new( - conn: Rc, + conn: Arc, io: Arc, syntax_config: Option, ) -> Self { @@ -141,7 +140,7 @@ impl Highlighter for LimboHelper { } pub struct SqlCompleter { - conn: Rc, + conn: Arc, io: Arc, // Has to be a ref cell as Rustyline takes immutable reference to self // This problem would be solved with Reedline as it uses &mut self for completions @@ -150,7 +149,7 @@ pub struct SqlCompleter { } impl SqlCompleter { - pub fn new(conn: Rc, io: Arc) -> Self { + pub fn new(conn: Arc, io: Arc) -> Self { Self { conn, io, diff --git a/core/ext/dynamic.rs b/core/ext/dynamic.rs index 17138f268..363f01fa4 100644 --- a/core/ext/dynamic.rs +++ b/core/ext/dynamic.rs @@ -6,7 +6,6 @@ use libloading::{Library, Symbol}; use limbo_ext::{ExtensionApi, ExtensionApiRef, ExtensionEntryPoint, ResultCode, VfsImpl}; use std::{ ffi::{c_char, CString}, - rc::Rc, sync::{Arc, Mutex, OnceLock}, }; @@ -31,7 +30,7 @@ unsafe impl Sync for VfsMod {} impl Connection { pub fn load_extension>( - self: &Rc, + self: &Arc, path: P, ) -> crate::Result<()> { use limbo_ext::ExtensionApiRef; diff --git a/core/ext/vtab_xconnect.rs b/core/ext/vtab_xconnect.rs index cae4485e5..07a7596e1 100644 --- a/core/ext/vtab_xconnect.rs +++ b/core/ext/vtab_xconnect.rs @@ -5,7 +5,7 @@ use std::{ ffi::{c_char, c_void, CStr, CString}, num::NonZeroUsize, ptr, - rc::Weak, + sync::Arc, }; /// Free memory for the internal context of the connection. @@ -17,7 +17,7 @@ pub unsafe extern "C" fn close(ctx: *mut c_void) { } // only free the memory for the boxed connection, we don't upgrade // or actually close the core connection, as we were 'sharing' it. - let _ = Box::from_raw(ctx as *mut Weak); + let _ = Box::from_raw(ctx as *mut Arc); } /// Wrapper around core Connection::execute with optional arguments to bind @@ -41,12 +41,8 @@ pub unsafe extern "C" fn execute( tracing::error!("query: null connection"); return ResultCode::Error; }; - let weak_ptr = extcon._ctx as *const Weak; - let weak = &*weak_ptr; - let Some(conn) = weak.upgrade() else { - tracing::error!("prepare_stmt: failed to upgrade weak pointer in prepare stmt"); - return ResultCode::Error; - }; + let conn_ptr = extcon._ctx as *const Arc; + let conn = &*conn_ptr; match conn.query(&sql_str) { Ok(Some(mut stmt)) => { if arg_count > 0 { @@ -102,12 +98,8 @@ pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) -> tracing::error!("prepare_stmt: null connection"); return ptr::null_mut(); }; - let weak_ptr = extcon._ctx as *const Weak; - let weak = &*weak_ptr; - let Some(conn) = weak.upgrade() else { - tracing::error!("prepare_stmt: failed to upgrade weak pointer in prepare stmt"); - return ptr::null_mut(); - }; + let db_ptr = extcon._ctx as *const Arc; + let conn = &*db_ptr; match conn.prepare(&sql_str) { Ok(stmt) => { let raw_stmt = Box::into_raw(Box::new(stmt)) as *mut c_void; diff --git a/core/lib.rs b/core/lib.rs index 6beae3585..679ff78e7 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -209,7 +209,7 @@ impl Database { Ok(db) } - pub fn connect(self: &Arc) -> Result> { + pub fn connect(self: &Arc) -> Result> { let buffer_pool = Rc::new(BufferPool::new(self.page_size as usize)); let wal = Rc::new(RefCell::new(WalFile::new( @@ -227,7 +227,7 @@ impl Database { Arc::new(RwLock::new(DumbLruPageCache::default())), buffer_pool, )?); - let conn = Rc::new(Connection { + let conn = Arc::new(Connection { _db: self.clone(), pager: pager.clone(), schema: self.schema.clone(), @@ -345,7 +345,7 @@ pub struct Connection { impl Connection { #[instrument(skip_all, level = Level::TRACE)] - pub fn prepare(self: &Rc, sql: impl AsRef) -> Result { + pub fn prepare(self: &Arc, sql: impl AsRef) -> Result { if sql.as_ref().is_empty() { return Err(LimboError::InvalidArgument( "The supplied SQL string contains no statements".to_string(), @@ -372,7 +372,7 @@ impl Connection { stmt, self.header.clone(), self.pager.clone(), - Rc::downgrade(self), + self.clone(), &syms, QueryMode::Normal, &input, @@ -389,7 +389,7 @@ impl Connection { } #[instrument(skip_all, level = Level::TRACE)] - pub fn query(self: &Rc, sql: impl AsRef) -> Result> { + pub fn query(self: &Arc, sql: impl AsRef) -> Result> { let sql = sql.as_ref(); tracing::trace!("Querying: {}", sql); let mut parser = Parser::new(sql.as_bytes()); @@ -406,7 +406,7 @@ impl Connection { #[instrument(skip_all, level = Level::TRACE)] pub(crate) fn run_cmd( - self: &Rc, + self: &Arc, cmd: Cmd, input: &str, ) -> Result> { @@ -421,7 +421,7 @@ impl Connection { stmt.clone(), self.header.clone(), self.pager.clone(), - Rc::downgrade(self), + self.clone(), &syms, cmd.into(), input, @@ -464,14 +464,14 @@ impl Connection { } } - pub fn query_runner<'a>(self: &'a Rc, sql: &'a [u8]) -> QueryRunner<'a> { + pub fn query_runner<'a>(self: &'a Arc, sql: &'a [u8]) -> QueryRunner<'a> { QueryRunner::new(self, sql) } /// Execute will run a query from start to finish taking ownership of I/O because it will run pending I/Os if it didn't finish. /// TODO: make this api async #[instrument(skip_all, level = Level::TRACE)] - pub fn execute(self: &Rc, sql: impl AsRef) -> Result<()> { + pub fn execute(self: &Arc, sql: impl AsRef) -> Result<()> { let sql = sql.as_ref(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; @@ -491,7 +491,7 @@ impl Connection { stmt, self.header.clone(), self.pager.clone(), - Rc::downgrade(self), + self.clone(), &syms, QueryMode::Explain, &input, @@ -508,7 +508,7 @@ impl Connection { stmt, self.header.clone(), self.pager.clone(), - Rc::downgrade(self), + self.clone(), &syms, QueryMode::Normal, &input, @@ -620,7 +620,7 @@ impl Connection { self.auto_commit.get() } - pub fn parse_schema_rows(self: &Rc) -> Result<()> { + pub fn parse_schema_rows(self: &Arc) -> Result<()> { let rows = self.query("SELECT * FROM sqlite_schema")?; let mut schema = self .schema @@ -641,7 +641,7 @@ impl Connection { // Clearly there is something to improve here, Vec> isn't a couple of tea /// Query the current rows/values of `pragma_name`. - pub fn pragma_query(self: &Rc, pragma_name: &str) -> Result>> { + pub fn pragma_query(self: &Arc, pragma_name: &str) -> Result>> { let pragma = format!("PRAGMA {}", pragma_name); let mut stmt = self.prepare(pragma)?; let mut results = Vec::new(); @@ -671,7 +671,7 @@ impl Connection { /// Some pragmas will return the updated value which cannot be retrieved /// with this method. pub fn pragma_update( - self: &Rc, + self: &Arc, pragma_name: &str, pragma_value: V, ) -> Result>> { @@ -706,7 +706,7 @@ impl Connection { /// (e.g. `table_info('one_tbl')`) or pragmas which returns value(s) /// (e.g. `integrity_check`). pub fn pragma( - self: &Rc, + self: &Arc, pragma_name: &str, pragma_value: V, ) -> Result>> { @@ -876,13 +876,13 @@ impl SymbolTable { pub struct QueryRunner<'a> { parser: Parser<'a>, - conn: &'a Rc, + conn: &'a Arc, statements: &'a [u8], last_offset: usize, } impl<'a> QueryRunner<'a> { - pub(crate) fn new(conn: &'a Rc, statements: &'a [u8]) -> Self { + pub(crate) fn new(conn: &'a Arc, statements: &'a [u8]) -> Self { Self { parser: Parser::new(statements), conn, diff --git a/core/pragma.rs b/core/pragma.rs index 2d820a77f..1904ffbb6 100644 --- a/core/pragma.rs +++ b/core/pragma.rs @@ -1,8 +1,8 @@ use crate::{Connection, LimboError, Statement, StepResult, Value}; use bitflags::bitflags; use limbo_sqlite3_parser::ast::PragmaName; -use std::rc::{Rc, Weak}; use std::str::FromStr; +use std::sync::Arc; bitflags! { // Flag names match those used in SQLite: @@ -141,13 +141,11 @@ impl PragmaVirtualTable { )) } - pub(crate) fn open(&self, conn: Weak) -> crate::Result { + pub(crate) fn open(&self, conn: Arc) -> crate::Result { Ok(PragmaVirtualTableCursor { pragma_name: self.pragma_name.clone(), pos: 0, - conn: conn - .upgrade() - .ok_or_else(|| LimboError::InternalError("Connection was dropped".into()))?, + conn, stmt: None, arg: None, visible_column_count: self.visible_column_count, @@ -160,7 +158,7 @@ impl PragmaVirtualTable { pub struct PragmaVirtualTableCursor { pragma_name: String, pos: usize, - conn: Rc, + conn: Arc, stmt: Option, arg: Option, visible_column_count: usize, diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 80490f890..8e2dc7754 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -6227,7 +6227,7 @@ mod tests { pos: usize, page: &mut PageContent, record: ImmutableRecord, - conn: &Rc, + conn: &Arc, ) -> Vec { let mut payload: Vec = Vec::new(); fill_cell_payload( diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 39d37056c..27cfde0b0 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -46,7 +46,7 @@ use insert::translate_insert; use limbo_sqlite3_parser::ast::{self, Delete, Insert}; use schema::{translate_create_table, translate_create_virtual_table, translate_drop_table}; use select::translate_select; -use std::rc::{Rc, Weak}; +use std::rc::Rc; use std::sync::Arc; use tracing::{instrument, Level}; use transaction::{translate_tx_begin, translate_tx_commit}; @@ -58,7 +58,7 @@ pub fn translate( stmt: ast::Stmt, database_header: Arc>, pager: Rc, - connection: Weak, + connection: Arc, syms: &SymbolTable, query_mode: QueryMode, _input: &str, // TODO: going to be used for CREATE VIEW diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index bad117bea..72f740ab0 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -3,7 +3,7 @@ use limbo_sqlite3_parser::ast::PragmaName; use limbo_sqlite3_parser::ast::{self, Expr}; -use std::rc::{Rc, Weak}; +use std::rc::Rc; use std::sync::Arc; use crate::fast_lock::SpinLock; @@ -34,7 +34,7 @@ pub fn translate_pragma( body: Option, database_header: Arc>, pager: Rc, - connection: Weak, + connection: Arc, mut program: ProgramBuilder, ) -> crate::Result { let opts = ProgramBuilderOpts { @@ -124,7 +124,7 @@ fn update_pragma( value: ast::Expr, header: Arc>, pager: Rc, - connection: Weak, + connection: Arc, program: &mut ProgramBuilder, ) -> crate::Result<()> { match pragma { @@ -268,16 +268,13 @@ fn query_pragma( value: Option, database_header: Arc>, pager: Rc, - connection: Weak, + connection: Arc, program: &mut ProgramBuilder, ) -> crate::Result<()> { let register = program.alloc_register(); match pragma { PragmaName::CacheSize => { - program.emit_int( - connection.upgrade().unwrap().get_cache_size() as i64, - register, - ); + program.emit_int(connection.get_cache_size() as i64, register); program.emit_result_row(register, 1); program.add_pragma_result_column(pragma.to_string()); } @@ -417,7 +414,7 @@ fn update_cache_size( value: i64, header: Arc>, pager: Rc, - connection: Weak, + connection: Arc, ) -> crate::Result<()> { let mut cache_size_unformatted: i64 = value; let mut cache_size = if cache_size_unformatted < 0 { @@ -432,10 +429,7 @@ fn update_cache_size( cache_size = MIN_PAGE_CACHE_SIZE; cache_size_unformatted = MIN_PAGE_CACHE_SIZE as i64; } - connection - .upgrade() - .unwrap() - .set_cache_size(cache_size_unformatted as i32); + connection.set_cache_size(cache_size_unformatted as i32); // update cache size pager diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 8254695d8..e90f6d736 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -1,9 +1,4 @@ -use std::{ - cell::Cell, - cmp::Ordering, - rc::{Rc, Weak}, - sync::Arc, -}; +use std::{cell::Cell, cmp::Ordering, rc::Rc, sync::Arc}; use limbo_sqlite3_parser::ast::{self, TableInternalId}; use tracing::{instrument, Level}; @@ -856,7 +851,7 @@ impl ProgramBuilder { pub fn build( mut self, database_header: Arc>, - connection: Weak, + connection: Arc, change_cnt_on: bool, ) -> Program { self.resolve_labels(); diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 16fb32db1..77aee268a 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -214,10 +214,8 @@ pub fn op_drop_index( let Insn::DropIndex { index, db: _ } = insn else { unreachable!("unexpected Insn {:?}", insn) }; - if let Some(conn) = program.connection.upgrade() { - let mut schema = conn.schema.write(); - schema.remove_index(&index); - } + let mut schema = program.connection.schema.write(); + schema.remove_index(&index); state.pc += 1; Ok(InsnFunctionStepResult::Step) } @@ -310,7 +308,7 @@ pub fn op_checkpoint( else { unreachable!("unexpected Insn {:?}", insn) }; - let result = program.connection.upgrade().unwrap().checkpoint(); + let result = program.connection.checkpoint(); match result { Ok(CheckpointResult { num_wal_frames: num_wal_pages, @@ -900,7 +898,7 @@ pub fn op_open_read( .replace(Cursor::new_btree(cursor)); } CursorType::BTreeIndex(index) => { - let conn = program.connection.upgrade().unwrap(); + let conn = program.connection.clone(); let schema = conn.schema.try_read().ok_or(LimboError::SchemaLocked)?; let table = schema .get_table(&index.table_name) @@ -998,11 +996,7 @@ pub fn op_vcreate( } else { vec![] }; - let Some(conn) = program.connection.upgrade() else { - return Err(crate::LimboError::ExtensionError( - "Failed to upgrade Connection".to_string(), - )); - }; + let conn = program.connection.clone(); let table = crate::VirtualTable::table(Some(&table_name), &module_name, args, &conn.syms.borrow())?; { @@ -1123,9 +1117,7 @@ pub fn op_vupdate( Ok(Some(new_rowid)) => { if *conflict_action == 5 { // ResolveType::Replace - if let Some(conn) = program.connection.upgrade() { - conn.update_last_rowid(new_rowid); - } + program.connection.update_last_rowid(new_rowid); } state.pc += 1; } @@ -1181,12 +1173,7 @@ pub fn op_vdestroy( let Insn::VDestroy { db, table_name } = insn else { unreachable!("unexpected Insn {:?}", insn) }; - let Some(conn) = program.connection.upgrade() else { - return Err(crate::LimboError::ExtensionError( - "Failed to upgrade Connection".to_string(), - )); - }; - + let conn = program.connection.clone(); { let Some(vtab) = conn.syms.borrow_mut().vtabs.remove(table_name) else { return Err(crate::LimboError::InternalError( @@ -1691,7 +1678,7 @@ pub fn op_transaction( let Insn::Transaction { write } = insn else { unreachable!("unexpected Insn {:?}", insn) }; - let connection = program.connection.upgrade().unwrap(); + let connection = program.connection.clone(); if *write && connection._db.open_flags.contains(OpenFlags::ReadOnly) { return Err(LimboError::ReadOnly); } @@ -1746,7 +1733,7 @@ pub fn op_auto_commit( else { unreachable!("unexpected Insn {:?}", insn) }; - let conn = program.connection.upgrade().unwrap(); + let conn = program.connection.clone(); if state.commit_state == CommitState::Committing { return match program.commit_txn(pager.clone(), state, mv_store)? { super::StepResult::Done => Ok(InsnFunctionStepResult::Done), @@ -3387,7 +3374,7 @@ pub fn op_function( state.registers[*dest] = Register::Value(result); } ScalarFunc::Changes => { - let res = &program.connection.upgrade().unwrap().last_change; + let res = &program.connection.last_change; let changes = res.get(); state.registers[*dest] = Register::Value(Value::Integer(changes)); } @@ -3435,12 +3422,8 @@ pub fn op_function( state.registers[*dest] = Register::Value(result); } ScalarFunc::LastInsertRowid => { - if let Some(conn) = program.connection.upgrade() { - state.registers[*dest] = - Register::Value(Value::Integer(conn.last_insert_rowid() as i64)); - } else { - state.registers[*dest] = Register::Value(Value::Null); - } + state.registers[*dest] = + Register::Value(Value::Integer(program.connection.last_insert_rowid() as i64)); } ScalarFunc::Like => { let pattern = &state.registers[*start_reg]; @@ -3645,7 +3628,7 @@ pub fn op_function( } } ScalarFunc::TotalChanges => { - let res = &program.connection.upgrade().unwrap().total_changes; + let res = &program.connection.total_changes; let total_changes = res.get(); state.registers[*dest] = Register::Value(Value::Integer(total_changes)); } @@ -3706,9 +3689,7 @@ pub fn op_function( ScalarFunc::LoadExtension => { let extension = &state.registers[*start_reg]; let ext = resolve_ext_path(&extension.get_owned_value().to_string())?; - if let Some(conn) = program.connection.upgrade() { - conn.load_extension(ext)?; - } + program.connection.load_extension(ext)?; } ScalarFunc::StrfTime => { let result = exec_strftime(&state.registers[*start_reg..*start_reg + arg_count]); @@ -4234,9 +4215,7 @@ pub fn op_insert( // Only update last_insert_rowid for regular table inserts, not schema modifications if cursor.root_page() != 1 { if let Some(rowid) = return_if_io!(cursor.rowid()) { - if let Some(conn) = program.connection.upgrade() { - conn.update_last_rowid(rowid); - } + program.connection.update_last_rowid(rowid); let prev_changes = program.n_change.get(); program.n_change.set(prev_changes + 1); } @@ -4697,7 +4676,7 @@ pub fn op_open_write( None => None, }; if let Some(index) = maybe_index { - let conn = program.connection.upgrade().unwrap(); + let conn = program.connection.clone(); let schema = conn.schema.try_read().ok_or(LimboError::SchemaLocked)?; let table = schema .get_table(&index.table_name) @@ -4823,7 +4802,8 @@ pub fn op_drop_table( if *db > 0 { todo!("temp databases not implemented yet"); } - if let Some(conn) = program.connection.upgrade() { + let conn = program.connection.clone(); + { let mut schema = conn.schema.write(); schema.remove_indices_for_table(table_name); schema.remove_table(table_name); @@ -4900,8 +4880,7 @@ pub fn op_parse_schema( else { unreachable!("unexpected Insn {:?}", insn) }; - let conn = program.connection.upgrade(); - let conn = conn.as_ref().unwrap(); + let conn = program.connection.clone(); if let Some(where_clause) = where_clause { let stmt = conn.prepare(format!( @@ -5187,7 +5166,7 @@ pub fn op_open_ephemeral( _ => unreachable!("unexpected Insn {:?}", insn), }; - let conn = program.connection.upgrade().unwrap(); + let conn = program.connection.clone(); let io = conn.pager.io.get_memory_io(); let file = io.open_file("", OpenFlags::Create, true)?; diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 46704d7b5..35a5fd1cd 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -54,8 +54,7 @@ use std::{ cell::{Cell, RefCell}, collections::HashMap, num::NonZero, - ops::Deref, - rc::{Rc, Weak}, + rc::Rc, sync::Arc, }; use tracing::{instrument, Level}; @@ -351,7 +350,6 @@ macro_rules! must_be_btree_cursor { }}; } -#[derive(Debug)] pub struct Program { pub max_registers: usize, pub insns: Vec<(Insn, InsnFunction)>, @@ -359,7 +357,7 @@ pub struct Program { pub database_header: Arc>, pub comments: Option>, pub parameters: crate::parameters::Parameters, - pub connection: Weak, + pub connection: Arc, pub n_change: Cell, pub change_cnt_on: bool, pub result_columns: Vec, @@ -401,7 +399,7 @@ impl Program { mv_store: Option<&Rc>, ) -> Result { if let Some(mv_store) = mv_store { - let conn = self.connection.upgrade().unwrap(); + let conn = self.connection.clone(); let auto_commit = conn.auto_commit.get(); if auto_commit { let mut mv_transactions = conn.mv_transactions.borrow_mut(); @@ -412,21 +410,18 @@ impl Program { } Ok(StepResult::Done) } else { - let connection = self - .connection - .upgrade() - .expect("only weak ref to connection?"); + let connection = self.connection.clone(); let auto_commit = connection.auto_commit.get(); tracing::trace!("Halt auto_commit {}", auto_commit); if program_state.commit_state == CommitState::Committing { - self.step_end_write_txn(&pager, &mut program_state.commit_state, connection.deref()) + self.step_end_write_txn(&pager, &mut program_state.commit_state, &connection) } else if auto_commit { let current_state = connection.transaction_state.get(); match current_state { TransactionState::Write => self.step_end_write_txn( &pager, &mut program_state.commit_state, - connection.deref(), + &connection, ), TransactionState::Read => { connection.transaction_state.replace(TransactionState::None); @@ -437,9 +432,7 @@ impl Program { } } else { if self.change_cnt_on { - if let Some(conn) = self.connection.upgrade() { - conn.set_changes(self.n_change.get()); - } + self.connection.set_changes(self.n_change.get()); } Ok(StepResult::Done) } @@ -457,9 +450,7 @@ impl Program { match cacheflush_status { PagerCacheflushStatus::Done(_) => { if self.change_cnt_on { - if let Some(conn) = self.connection.upgrade() { - conn.set_changes(self.n_change.get()); - } + self.connection.set_changes(self.n_change.get()); } connection.transaction_state.replace(TransactionState::None); *commit_state = CommitState::Ready; diff --git a/core/vtab.rs b/core/vtab.rs index 3297b0346..b9d81781a 100644 --- a/core/vtab.rs +++ b/core/vtab.rs @@ -7,7 +7,8 @@ use limbo_ext::{ConstraintInfo, IndexInfo, OrderByInfo, ResultCode, VTabKind, VT use limbo_sqlite3_parser::{ast, lexer::sql::Parser}; use std::cell::RefCell; use std::ffi::c_void; -use std::rc::{Rc, Weak}; +use std::rc::Rc; +use std::sync::Arc; #[derive(Debug, Clone)] enum VirtualTableType { @@ -90,7 +91,7 @@ impl VirtualTable { } } - pub(crate) fn open(&self, conn: Weak) -> crate::Result { + pub(crate) fn open(&self, conn: Arc) -> crate::Result { match &self.vtab_type { VirtualTableType::Pragma(table) => Ok(VirtualTableCursor::Pragma(table.open(conn)?)), VirtualTableType::External(table) => { @@ -237,11 +238,11 @@ impl ExtVirtualTable { Ok((vtab, schema)) } - /// Accepts a Weak pointer to the connection that owns the VTable, that the module + /// Accepts a pointer connection that owns the VTable, that the module /// can optionally use to query the other tables. - fn open(&self, conn: Weak) -> crate::Result { + fn open(&self, conn: Arc) -> crate::Result { // we need a Weak to upgrade and call from the extension. - let weak_box: *mut Weak = Box::into_raw(Box::new(conn)); + let weak_box: *mut Arc = Box::into_raw(Box::new(conn)); let conn = limbo_ext::Conn::new( weak_box.cast(), crate::ext::prepare_stmt, diff --git a/extensions/completion/src/lib.rs b/extensions/completion/src/lib.rs index ae7a9d50e..db4ca26bf 100644 --- a/extensions/completion/src/lib.rs +++ b/extensions/completion/src/lib.rs @@ -3,7 +3,7 @@ mod keywords; -use std::rc::Rc; +use std::sync::Arc; use keywords::KEYWORDS; use limbo_ext::{ @@ -87,7 +87,7 @@ impl VTable for CompletionTable { type Cursor = CompletionCursor; type Error = ResultCode; - fn open(&self, _conn: Option>) -> Result { + fn open(&self, _conn: Option>) -> Result { Ok(CompletionCursor::default()) } } diff --git a/extensions/core/src/vtabs.rs b/extensions/core/src/vtabs.rs index a77e709a5..fb26f6499 100644 --- a/extensions/core/src/vtabs.rs +++ b/extensions/core/src/vtabs.rs @@ -2,7 +2,7 @@ use crate::{types::StepResult, ExtResult, ResultCode, Value}; use std::{ ffi::{c_char, c_void, CStr, CString}, num::NonZeroUsize, - rc::Rc, + sync::Arc, }; pub type RegisterModuleFn = unsafe extern "C" fn( @@ -131,7 +131,7 @@ pub trait VTable { /// 'conn' is an Option to allow for testing. Otherwise a valid connection to the core database /// that created the virtual table will be available to use in your extension here. - fn open(&self, _conn: Option>) -> Result; + fn open(&self, _conn: Option>) -> Result; fn update(&mut self, _rowid: i64, _args: &[Value]) -> Result<(), Self::Error> { Ok(()) } @@ -463,7 +463,7 @@ impl Connection { } /// From the included SQL string, prepare a statement for execution. - pub fn prepare(self: &Rc, sql: &str) -> ExtResult { + pub fn prepare(self: &Arc, sql: &str) -> ExtResult { let stmt = unsafe { (*self.0).prepare_stmt(sql) }; if stmt.is_null() { return Err(ResultCode::Error); @@ -473,7 +473,7 @@ impl Connection { /// Execute a SQL statement with the given arguments. /// Optionally returns the last inserted rowid for the query. - pub fn execute(self: &Rc, sql: &str, args: &[Value]) -> crate::ExtResult> { + pub fn execute(self: &Arc, sql: &str, args: &[Value]) -> crate::ExtResult> { if self.0.is_null() { return Err(ResultCode::Error); } diff --git a/extensions/csv/src/lib.rs b/extensions/csv/src/lib.rs index 3dfd0e2fa..55dee67fc 100644 --- a/extensions/csv/src/lib.rs +++ b/extensions/csv/src/lib.rs @@ -26,7 +26,7 @@ use limbo_ext::{ }; use std::fs::File; use std::io::{Read, Seek, SeekFrom}; -use std::rc::Rc; +use std::sync::Arc; register_extension! { vtabs: { CsvVTabModule } @@ -260,7 +260,7 @@ impl VTable for CsvTable { type Cursor = CsvCursor; type Error = ResultCode; - fn open(&self, _conn: Option>) -> Result { + fn open(&self, _conn: Option>) -> Result { match self.new_reader() { Ok(reader) => Ok(CsvCursor::new(reader, self)), Err(_) => Err(ResultCode::Error), diff --git a/extensions/series/src/lib.rs b/extensions/series/src/lib.rs index 437940b23..26fe192de 100644 --- a/extensions/series/src/lib.rs +++ b/extensions/series/src/lib.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::sync::Arc; use limbo_ext::{ register_extension, Connection, ResultCode, VTabCursor, VTabKind, VTabModule, VTabModuleDerive, @@ -45,7 +45,7 @@ impl VTable for GenerateSeriesTable { type Cursor = GenerateSeriesCursor; type Error = ResultCode; - fn open(&self, _conn: Option>) -> Result { + fn open(&self, _conn: Option>) -> Result { Ok(GenerateSeriesCursor { start: 0, stop: 0, diff --git a/extensions/tests/src/lib.rs b/extensions/tests/src/lib.rs index 9c8ad0019..e5acde939 100644 --- a/extensions/tests/src/lib.rs +++ b/extensions/tests/src/lib.rs @@ -10,8 +10,7 @@ use std::collections::BTreeMap; use std::fs::{File, OpenOptions}; use std::io::{Read, Seek, SeekFrom, Write}; use std::num::NonZeroUsize; -use std::rc::Rc; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; register_extension! { vtabs: { KVStoreVTabModule, TableStatsVtabModule }, @@ -139,7 +138,7 @@ impl VTable for KVStoreTable { type Cursor = KVStoreCursor; type Error = String; - fn open(&self, _conn: Option>) -> Result { + fn open(&self, _conn: Option>) -> Result { let _ = env_logger::try_init(); Ok(KVStoreCursor { rows: Vec::new(), @@ -303,7 +302,7 @@ pub struct TableStatsVtabModule; pub struct StatsCursor { pos: usize, rows: Vec<(String, i64)>, - conn: Option>, + conn: Option>, } pub struct StatsTable {} @@ -322,7 +321,7 @@ impl VTable for StatsTable { type Cursor = StatsCursor; type Error = String; - fn open(&self, conn: Option>) -> Result { + fn open(&self, conn: Option>) -> Result { Ok(StatsCursor { pos: 0, rows: Vec::new(), diff --git a/macros/src/ext/vtab_derive.rs b/macros/src/ext/vtab_derive.rs index eb5112a90..9fe94de05 100644 --- a/macros/src/ext/vtab_derive.rs +++ b/macros/src/ext/vtab_derive.rs @@ -55,7 +55,7 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream { } let table = table as *const <#struct_name as ::limbo_ext::VTabModule>::Table; let table: &<#struct_name as ::limbo_ext::VTabModule>::Table = &*table; - let conn = if conn.is_null() { None } else { Some(::std::rc::Rc::new(::limbo_ext::Connection::new(conn)))}; + let conn = if conn.is_null() { None } else { Some(::std::sync::Arc::new(::limbo_ext::Connection::new(conn)))}; if let Ok(cursor) = <#struct_name as ::limbo_ext::VTabModule>::Table::open(table, conn) { return ::std::boxed::Box::into_raw(::std::boxed::Box::new(cursor)) as *const ::std::ffi::c_void; } else { diff --git a/simulator/generation/plan.rs b/simulator/generation/plan.rs index 4bb408885..ec6ee471e 100644 --- a/simulator/generation/plan.rs +++ b/simulator/generation/plan.rs @@ -2,7 +2,7 @@ use std::{ collections::HashSet, fmt::{Debug, Display}, path::Path, - rc::Rc, + sync::Arc, vec, }; @@ -504,7 +504,7 @@ impl Interaction { Self::Assumption(_) | Self::Assertion(_) | Self::Fault(_) => vec![], } } - pub(crate) fn execute_query(&self, conn: &mut Rc, io: &SimulatorIO) -> ResultSet { + pub(crate) fn execute_query(&self, conn: &mut Arc, io: &SimulatorIO) -> ResultSet { if let Self::Query(query) = self { let query_str = query.to_string(); let rows = conn.query(&query_str); diff --git a/simulator/runner/env.rs b/simulator/runner/env.rs index b5a06d046..78ff3c605 100644 --- a/simulator/runner/env.rs +++ b/simulator/runner/env.rs @@ -1,7 +1,6 @@ use std::fmt::Display; use std::mem; use std::path::Path; -use std::rc::Rc; use std::sync::Arc; use limbo_core::Database; @@ -164,7 +163,7 @@ where } pub(crate) enum SimConnection { - LimboConnection(Rc), + LimboConnection(Arc), SQLiteConnection(rusqlite::Connection), Disconnected, } diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 8066bd7d5..a3b21ff54 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -5,7 +5,6 @@ use limbo_core::Value; use std::ffi::{self, CStr, CString}; use tracing::trace; -use std::rc::Rc; use std::sync::Arc; macro_rules! stub { @@ -42,7 +41,7 @@ use util::sqlite3_safety_check_sick_or_ok; pub struct sqlite3 { pub(crate) io: Arc, pub(crate) _db: Arc, - pub(crate) conn: Rc, + pub(crate) conn: Arc, pub(crate) err_code: ffi::c_int, pub(crate) err_mask: ffi::c_int, pub(crate) malloc_failed: bool, @@ -54,7 +53,7 @@ impl sqlite3 { pub fn new( io: Arc, db: Arc, - conn: Rc, + conn: Arc, ) -> Self { Self { io, diff --git a/tests/integration/common.rs b/tests/integration/common.rs index 4de009e33..8d0fc7830 100644 --- a/tests/integration/common.rs +++ b/tests/integration/common.rs @@ -2,7 +2,6 @@ use limbo_core::{Connection, Database, PagerCacheflushStatus, IO}; use rand::{rng, RngCore}; use rusqlite::params; use std::path::{Path, PathBuf}; -use std::rc::Rc; use std::sync::Arc; use tempfile::TempDir; use tracing_subscriber::layer::SubscriberExt; @@ -55,14 +54,14 @@ impl TempDatabase { Self { path, io } } - pub fn connect_limbo(&self) -> Rc { + pub fn connect_limbo(&self) -> Arc { Self::connect_limbo_with_flags(&self, limbo_core::OpenFlags::default()) } pub fn connect_limbo_with_flags( &self, flags: limbo_core::OpenFlags, - ) -> Rc { + ) -> Arc { log::debug!("conneting to limbo"); let db = Database::open_file_with_flags( self.io.clone(), @@ -83,7 +82,7 @@ impl TempDatabase { } } -pub(crate) fn do_flush(conn: &Rc, tmp_db: &TempDatabase) -> anyhow::Result<()> { +pub(crate) fn do_flush(conn: &Arc, tmp_db: &TempDatabase) -> anyhow::Result<()> { loop { match conn.cacheflush()? { PagerCacheflushStatus::Done(_) => { @@ -155,7 +154,7 @@ pub(crate) fn sqlite_exec_rows( pub(crate) fn limbo_exec_rows( db: &TempDatabase, - conn: &Rc, + conn: &Arc, query: &str, ) -> Vec> { let mut stmt = conn.prepare(query).unwrap(); @@ -193,7 +192,7 @@ pub(crate) fn limbo_exec_rows( pub(crate) fn limbo_exec_rows_error( db: &TempDatabase, - conn: &Rc, + conn: &Arc, query: &str, ) -> limbo_core::Result<()> { let mut stmt = conn.prepare(query)?; diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index edebd2821..dd1c0d7c9 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -2,7 +2,7 @@ use crate::common::{self, maybe_setup_tracing}; use crate::common::{compare_string, do_flush, TempDatabase}; use limbo_core::{Connection, Row, StepResult, Value}; use log::debug; -use std::rc::Rc; +use std::sync::Arc; #[test] #[ignore] @@ -286,7 +286,7 @@ fn test_wal_restart() -> anyhow::Result<()> { let tmp_db = TempDatabase::new_with_rusqlite("CREATE TABLE test (x INTEGER PRIMARY KEY);"); // threshold is 1000 by default - fn insert(i: usize, conn: &Rc, tmp_db: &TempDatabase) -> anyhow::Result<()> { + fn insert(i: usize, conn: &Arc, tmp_db: &TempDatabase) -> anyhow::Result<()> { debug!("inserting {}", i); let insert_query = format!("INSERT INTO test VALUES ({})", i); run_query(tmp_db, conn, &insert_query)?; @@ -295,7 +295,7 @@ fn test_wal_restart() -> anyhow::Result<()> { Ok(()) } - fn count(conn: &Rc, tmp_db: &TempDatabase) -> anyhow::Result { + fn count(conn: &Arc, tmp_db: &TempDatabase) -> anyhow::Result { debug!("counting"); let list_query = "SELECT count(x) FROM test"; let mut count = None; @@ -447,13 +447,13 @@ fn test_delete_with_index() -> anyhow::Result<()> { Ok(()) } -fn run_query(tmp_db: &TempDatabase, conn: &Rc, query: &str) -> anyhow::Result<()> { +fn run_query(tmp_db: &TempDatabase, conn: &Arc, query: &str) -> anyhow::Result<()> { run_query_core(tmp_db, conn, query, None::) } fn run_query_on_row( tmp_db: &TempDatabase, - conn: &Rc, + conn: &Arc, query: &str, on_row: impl FnMut(&Row), ) -> anyhow::Result<()> { @@ -462,7 +462,7 @@ fn run_query_on_row( fn run_query_core( tmp_db: &TempDatabase, - conn: &Rc, + conn: &Arc, query: &str, mut on_row: Option, ) -> anyhow::Result<()> { diff --git a/tests/integration/wal/test_wal.rs b/tests/integration/wal/test_wal.rs index 9936da739..3d2000c2a 100644 --- a/tests/integration/wal/test_wal.rs +++ b/tests/integration/wal/test_wal.rs @@ -112,7 +112,7 @@ fn test_wal_1_writer_1_reader() -> Result<()> { /// Execute a statement and get strings result pub(crate) fn execute_and_get_strings( tmp_db: &TempDatabase, - conn: &Rc, + conn: &Arc, sql: &str, ) -> Result> { let statement = conn.prepare(sql)?; @@ -140,7 +140,7 @@ pub(crate) fn execute_and_get_strings( /// Execute a statement and get integers pub(crate) fn execute_and_get_ints( tmp_db: &TempDatabase, - conn: &Rc, + conn: &Arc, sql: &str, ) -> Result> { let statement = conn.prepare(sql)?;