From c0e51c4ca600c2b03b8d7a500c2a02da52aeeb69 Mon Sep 17 00:00:00 2001 From: Pere Diaz Bou Date: Sun, 22 Sep 2024 17:45:39 +0200 Subject: [PATCH] wip wal --- bindings/wasm/lib.rs | 19 ++++ cli/main.rs | 31 ++++--- core/io/darwin.rs | 8 +- core/lib.rs | 43 ++++++--- core/storage/pager.rs | 16 +++- core/storage/sqlite3_ondisk.rs | 91 +++++++++++++++--- core/storage/wal.rs | 165 ++++++++++++++++++++++++++++++--- core/translate/emitter.rs | 10 +- core/translate/insert.rs | 8 +- core/translate/mod.rs | 19 ++-- core/translate/select.rs | 10 +- core/vdbe/builder.rs | 16 +++- core/vdbe/explain.rs | 4 +- core/vdbe/mod.rs | 43 +++++++-- sqlite3/src/lib.rs | 7 +- test/src/lib.rs | 3 +- 16 files changed, 405 insertions(+), 88 deletions(-) diff --git a/bindings/wasm/lib.rs b/bindings/wasm/lib.rs index 8c0fb8cd4..77ad9e6cb 100644 --- a/bindings/wasm/lib.rs +++ b/bindings/wasm/lib.rs @@ -127,6 +127,8 @@ impl DatabaseStorage { } } +struct BufferPool {} + impl limbo_core::DatabaseStorage for DatabaseStorage { fn read_page(&self, page_idx: usize, c: Rc) -> Result<()> { let r = match &(*c) { @@ -168,10 +170,27 @@ impl limbo_core::Wal for Wal { Ok(None) } + fn begin_write_tx(&self) -> Result<()> { + todo!() + } + + fn end_write_tx(&self) -> Result<()> { + todo!() + } + + fn append_frame( + &self, + _page: Rc>, + _db_size: u32, + ) -> Result<()> { + todo!() + } + fn read_frame( &self, _frame_id: u64, _page: Rc>, + _buffer_pool: Rc, ) -> Result<()> { todo!() } diff --git a/cli/main.rs b/cli/main.rs index 6f887902f..fcb339530 100644 --- a/cli/main.rs +++ b/cli/main.rs @@ -6,6 +6,7 @@ use limbo_core::{Database, RowResult, Value}; use opcodes_dictionary::OPCODE_DESCRIPTIONS; use rustyline::{error::ReadlineError, DefaultEditor}; use std::path::PathBuf; +use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -147,7 +148,7 @@ Note: fn handle_dot_command( io: Arc, - conn: &limbo_core::Connection, + conn: &Rc, line: &str, ) -> anyhow::Result<()> { let args: Vec<&str> = line.split_whitespace().collect(); @@ -196,7 +197,7 @@ fn handle_dot_command( fn display_schema( io: Arc, - conn: &limbo_core::Connection, + conn: &Rc, table: Option<&str>, ) -> anyhow::Result<()> { let sql = match table { @@ -251,7 +252,7 @@ fn display_schema( fn query( io: Arc, - conn: &limbo_core::Connection, + conn: &Rc, sql: &str, output_mode: &OutputMode, interrupt_count: &Arc, @@ -264,8 +265,8 @@ fn query( return Ok(()); } - match rows.next_row()? { - RowResult::Row(row) => { + match rows.next_row() { + Ok(RowResult::Row(row)) => { for (i, value) in row.values.iter().enumerate() { if i > 0 { print!("|"); @@ -282,10 +283,14 @@ fn query( } println!(); } - RowResult::IO => { + Ok(RowResult::IO) => { io.run_once()?; } - RowResult::Done => { + Ok(RowResult::Done) => { + break; + } + Err(err) => { + eprintln!("{}", err); break; } } @@ -297,8 +302,8 @@ fn query( } let mut table_rows: Vec> = vec![]; loop { - match rows.next_row()? { - RowResult::Row(row) => { + match rows.next_row() { + Ok(RowResult::Row(row)) => { table_rows.push( row.values .iter() @@ -314,10 +319,14 @@ fn query( .collect(), ); } - RowResult::IO => { + Ok(RowResult::IO) => { io.run_once()?; } - RowResult::Done => break, + Ok(RowResult::Done) => break, + Err(err) => { + eprintln!("{}", err); + break; + } } } let table = table_rows.table(); diff --git a/core/io/darwin.rs b/core/io/darwin.rs index 1016e9954..141a16541 100644 --- a/core/io/darwin.rs +++ b/core/io/darwin.rs @@ -153,7 +153,9 @@ impl File for DarwinFile { if lock_result == -1 { let err = std::io::Error::last_os_error(); if err.kind() == std::io::ErrorKind::WouldBlock { - return Err(LimboError::LockingError("Failed locking file. File is locked by another process".to_string())); + return Err(LimboError::LockingError( + "Failed locking file. File is locked by another process".to_string(), + )); } else { return Err(LimboError::LockingError(format!( "Failed locking file, {}", @@ -184,8 +186,8 @@ impl File for DarwinFile { Ok(()) } - fn pread(&self, pos: usize, c: Rc) -> Result<()> { - let file = self.file.borrow(); + fn pread(&self, pos: usize, c: Rc) -> Result<()> { + let file = self.file.borrow(); let result = { let r = match &(*c) { Completion::Read(r) => r, diff --git a/core/lib.rs b/core/lib.rs index 2b35f07a6..0581c4dd6 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -19,6 +19,8 @@ use log::trace; use schema::Schema; use sqlite3_parser::ast; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; +use std::rc::Weak; +use std::sync::Arc; use std::sync::{Arc, OnceLock}; use std::{cell::RefCell, rc::Rc}; #[cfg(feature = "fs")] @@ -44,15 +46,23 @@ pub use types::Value; pub static DATABASE_VERSION: OnceLock = OnceLock::new(); +#[derive(Clone)] +enum TransactionState { + Write, + Read, + None, +} + pub struct Database { pager: Rc, schema: Rc, header: Rc>, + transaction_state: RefCell, } impl Database { #[cfg(feature = "fs")] - pub fn open_file(io: Arc, path: &str) -> Result { + pub fn open_file(io: Arc, path: &str) -> Result> { let file = io.open_file(path)?; let page_io = Rc::new(FileStorage::new(file)); let wal_path = format!("{}-wal", path); @@ -64,7 +74,7 @@ impl Database { io: Arc, page_io: Rc, wal: Rc, - ) -> Result { + ) -> Result> { let db_header = Pager::begin_open(page_io.clone())?; DATABASE_VERSION.get_or_init(|| { let version = db_header.borrow().version_number; @@ -78,11 +88,12 @@ impl Database { io.clone(), )?); let bootstrap_schema = Rc::new(Schema::new()); - let conn = Connection { + let conn = Rc::new(Connection { pager: pager.clone(), schema: bootstrap_schema.clone(), header: db_header.clone(), - }; + db: Weak::new(), + }); let mut schema = Schema::new(); let rows = conn.query("SELECT * FROM sqlite_schema")?; if let Some(mut rows) = rows { @@ -126,19 +137,21 @@ impl Database { } let schema = Rc::new(schema); let header = db_header; - Ok(Database { + Ok(Rc::new(Database { pager, schema, header, - }) + transaction_state: RefCell::new(TransactionState::None), + })) } - pub fn connect(&self) -> Connection { - Connection { + pub fn connect(self: &Rc) -> Rc { + Rc::new(Connection { pager: self.pager.clone(), schema: self.schema.clone(), header: self.header.clone(), - } + db: Rc::downgrade(self), + }) } } @@ -146,10 +159,11 @@ pub struct Connection { pager: Rc, schema: Rc, header: Rc>, + db: Weak, // backpointer to the database holding this connection } impl Connection { - pub fn prepare(&self, sql: impl Into) -> Result { + pub fn prepare(self: &Rc, sql: impl Into) -> Result { let sql = sql.into(); trace!("Preparing: {}", sql); let mut parser = Parser::new(sql.as_bytes()); @@ -162,6 +176,7 @@ impl Connection { stmt, self.header.clone(), self.pager.clone(), + Rc::downgrade(self), )?); Ok(Statement::new(program, self.pager.clone())) } @@ -173,7 +188,7 @@ impl Connection { } } - pub fn query(&self, sql: impl Into) -> Result> { + pub fn query(self: &Rc, sql: impl Into) -> Result> { let sql = sql.into(); trace!("Querying: {}", sql); let mut parser = Parser::new(sql.as_bytes()); @@ -186,6 +201,7 @@ impl Connection { stmt, self.header.clone(), self.pager.clone(), + Rc::downgrade(&self), )?); let stmt = Statement::new(program, self.pager.clone()); Ok(Some(Rows { stmt })) @@ -196,6 +212,7 @@ impl Connection { stmt, self.header.clone(), self.pager.clone(), + Rc::downgrade(self), )?; program.explain(); Ok(None) @@ -217,7 +234,7 @@ impl Connection { } } - pub fn execute(&self, sql: impl Into) -> Result<()> { + pub fn execute(self: &Rc, sql: impl Into) -> Result<()> { let sql = sql.into(); let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next()?; @@ -229,6 +246,7 @@ impl Connection { stmt, self.header.clone(), self.pager.clone(), + Rc::downgrade(self), )?; program.explain(); } @@ -239,6 +257,7 @@ impl Connection { stmt, self.header.clone(), self.pager.clone(), + Rc::downgrade(self), )?; let mut state = vdbe::ProgramState::new(program.max_registers); program.step(&mut state, self.pager.clone())?; diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 2a463514f..025ce0e24 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -307,7 +307,12 @@ impl Pager { Ok(()) } - pub fn end_read_tx(&self) -> Result<()> { + pub fn begin_write_tx(&self) -> Result<()> { + self.wal.begin_read_tx()?; + Ok(()) + } + + pub fn end_tx(&self) -> Result<()> { self.wal.end_read_tx()?; Ok(()) } @@ -322,7 +327,9 @@ impl Pager { let page = Rc::new(RefCell::new(Page::new(page_idx))); RefCell::borrow(&page).set_locked(); if let Some(frame_id) = self.wal.find_frame(page_idx as u64)? { - self.wal.read_frame(frame_id, page.clone())?; + dbg!(frame_id); + self.wal + .read_frame(frame_id, page.clone(), self.buffer_pool.clone())?; { let page = page.borrow_mut(); page.set_uptodate(); @@ -361,10 +368,11 @@ impl Pager { if dirty_pages.len() == 0 { return Ok(()); } + let db_size = self.db_header.borrow().database_size; for page_id in dirty_pages.iter() { let mut cache = self.page_cache.borrow_mut(); - let page = cache.get(page_id).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it."); - sqlite3_ondisk::begin_write_btree_page(self, &page)?; + let page = cache.get(&page_id).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it."); + self.wal.append_frame(page.clone(), db_size, self)?; } dirty_pages.clear(); self.io.run_once()?; diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 9e0e695a3..7cd2b959f 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -87,11 +87,14 @@ pub struct DatabaseHeader { pub version_number: u32, } +pub const WAL_HEADER_SIZE: usize = 32; +pub const WAL_FRAME_HEADER_SIZE: usize = 24; + #[derive(Debug, Default)] pub struct WalHeader { magic: [u8; 4], file_format: u32, - page_size: u32, + pub page_size: u32, checkpoint_seq: u32, salt_1: u32, salt_2: u32, @@ -937,7 +940,7 @@ pub fn write_varint_to_vec(value: u64, payload: &mut Vec) { pub fn begin_read_wal_header(io: Rc) -> Result>> { let drop_fn = Rc::new(|_buf| {}); - let buf = Rc::new(RefCell::new(Buffer::allocate(32, drop_fn))); + let buf = Rc::new(RefCell::new(Buffer::allocate(WAL_HEADER_SIZE, drop_fn))); let result = Rc::new(RefCell::new(WalHeader::default())); let header = result.clone(); let complete = Box::new(move |buf: Rc>| { @@ -964,26 +967,86 @@ fn finish_read_wal_header(buf: Rc>, header: Rc, offset: usize, -) -> Result>> { - let drop_fn = Rc::new(|_buf| {}); - let buf = Rc::new(RefCell::new(Buffer::allocate(32, drop_fn))); - let result = Rc::new(RefCell::new(WalFrameHeader::default())); - let frame = result.clone(); + buffer_pool: Rc, + page: Rc>, +) -> Result<()> { + let buf = buffer_pool.get(); + let drop_fn = Rc::new(move |buf| { + let buffer_pool = buffer_pool.clone(); + buffer_pool.put(buf); + }); + let buf = Rc::new(RefCell::new(Buffer::new(buf, drop_fn))); + let frame = page.clone(); let complete = Box::new(move |buf: Rc>| { let frame = frame.clone(); - finish_read_wal_frame_header(buf, frame).unwrap(); + finish_read_page(2, buf, frame).unwrap(); }); let c = Rc::new(Completion::Read(ReadCompletion::new(buf, complete))); io.pread(offset, c)?; - Ok(result) + Ok(()) } -#[allow(dead_code)] -fn finish_read_wal_frame_header( +pub fn begin_write_wal_frame( + io: &Rc, + offset: usize, + page: &Rc>, + db_size: u32, +) -> Result<()> { + let page_finish = page.clone(); + let page_id = page.borrow().id; + + let header = WalFrameHeader { + page_number: page_id as u32, + db_size, + salt_1: 0, + salt_2: 0, + checksum_1: 0, + checksum_2: 0, + }; + let buffer = { + let page = page.borrow(); + let contents = page.contents.read().unwrap(); + let drop_fn = Rc::new(|_buf| {}); + let contents = contents.as_ref().unwrap(); + + let mut buffer = Buffer::allocate( + contents.buffer.borrow().len() + WAL_FRAME_HEADER_SIZE, + drop_fn, + ); + let buf = buffer.as_mut_slice(); + + buf[0..4].copy_from_slice(&header.page_number.to_ne_bytes()); + buf[4..8].copy_from_slice(&header.db_size.to_ne_bytes()); + buf[8..12].copy_from_slice(&header.salt_1.to_ne_bytes()); + buf[12..16].copy_from_slice(&header.salt_2.to_ne_bytes()); + buf[16..20].copy_from_slice(&header.checksum_1.to_ne_bytes()); + buf[20..24].copy_from_slice(&header.checksum_2.to_ne_bytes()); + buf[WAL_FRAME_HEADER_SIZE..].copy_from_slice(&contents.as_ptr()); + + Rc::new(RefCell::new(buffer)) + }; + + let write_complete = { + let buf_copy = buffer.clone(); + Box::new(move |bytes_written: i32| { + let buf_copy = buf_copy.clone(); + let buf_len = buf_copy.borrow().len(); + + page_finish.borrow_mut().clear_dirty(); + if bytes_written < buf_len as i32 { + log::error!("wrote({bytes_written}) less than expected({buf_len})"); + } + }) + }; + let c = Rc::new(Completion::Write(WriteCompletion::new(write_complete))); + io.pwrite(offset, buffer.clone(), c)?; + Ok(()) +} + +fn finish_read_wal_frame( buf: Rc>, frame: Rc>, ) -> Result<()> { diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 10583e26b..eb77c204a 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -1,23 +1,51 @@ +use std::collections::HashMap; use std::{cell::RefCell, rc::Rc, sync::Arc}; use crate::io::{File, IO}; +use crate::storage::sqlite3_ondisk::{ + begin_read_page, begin_read_wal_frame, begin_write_wal_frame, WAL_FRAME_HEADER_SIZE, + WAL_HEADER_SIZE, +}; use crate::{storage::pager::Page, Result}; +use super::buffer_pool::BufferPool; +use super::pager::Pager; use super::sqlite3_ondisk; /// Write-ahead log (WAL). pub trait Wal { - /// Begin a write transaction. + /// Begin a read transaction. fn begin_read_tx(&self) -> Result<()>; - /// End a write transaction. + /// Begin a write transaction. + fn begin_write_tx(&self) -> Result<()>; + + /// End a read transaction. fn end_read_tx(&self) -> Result<()>; + /// End a write transaction. + fn end_write_tx(&self) -> Result<()>; + /// Find the latest frame containing a page. fn find_frame(&self, page_id: u64) -> Result>; /// Read a frame from the WAL. - fn read_frame(&self, frame_id: u64, page: Rc>) -> Result<()>; + fn read_frame( + &self, + frame_id: u64, + page: Rc>, + buffer_pool: Rc, + ) -> Result<()>; + + /// Write a frame to the WAL. + fn append_frame( + &self, + page: Rc>, + db_size: u32, + pager: &Pager, + ) -> Result; + + fn checkpoint(&self, pager: &Pager) -> Result; } #[cfg(feature = "fs")] @@ -26,29 +54,113 @@ pub struct WalFile { wal_path: String, file: RefCell>>, wal_header: RefCell>>>, + min_frame: RefCell, + max_frame: RefCell, + nbackfills: RefCell, + // Maps pgno to frame id and offset in wal file + frame_cache: RefCell>>, // FIXME: for now let's use a simple hashmap instead of a shm file + checkpoint_threshold: usize, +} + +enum CheckpointStatus { + Done, + IO, } #[cfg(feature = "fs")] impl Wal for WalFile { - /// Begin a write transaction. + /// Begin a read transaction. fn begin_read_tx(&self) -> Result<()> { + self.min_frame.replace(*self.nbackfills.borrow() + 1); Ok(()) } - /// End a write transaction. + /// End a read transaction. fn end_read_tx(&self) -> Result<()> { Ok(()) } /// Find the latest frame containing a page. - fn find_frame(&self, _page_id: u64) -> Result> { + fn find_frame(&self, page_id: u64) -> Result> { + let frame_cache = self.frame_cache.borrow(); + dbg!(&frame_cache); + let frames = frame_cache.get(&page_id); + dbg!(&frames); + if frames.is_none() { + return Ok(None); + } self.ensure_init()?; + let frames = frames.unwrap(); + for frame in frames.iter().rev() { + if *frame <= *self.max_frame.borrow() { + return Ok(Some(*frame)); + } + } Ok(None) } /// Read a frame from the WAL. - fn read_frame(&self, _frame_id: u64, _page: Rc>) -> Result<()> { - todo!(); + fn read_frame( + &self, + frame_id: u64, + page: Rc>, + buffer_pool: Rc, + ) -> Result<()> { + println!("read frame {}", frame_id); + let offset = self.frame_offset(frame_id); + begin_read_wal_frame( + self.file.borrow().as_ref().unwrap(), + offset, + buffer_pool, + page, + )?; + Ok(()) + } + + /// Write a frame to the WAL. + fn append_frame(&self, page: Rc>, db_size: u32, pager: &Pager) -> Result<()> { + self.ensure_init()?; + let page_id = page.borrow().id; + let frame_id = *self.max_frame.borrow(); + let offset = self.frame_offset(frame_id); + println!("appending {} at {}", frame_id, offset); + begin_write_wal_frame(self.file.borrow().as_ref().unwrap(), offset, &page, db_size)?; + self.max_frame.replace(frame_id + 1); + let mut frame_cache = self.frame_cache.borrow_mut(); + let frames = frame_cache.get_mut(&(page_id as u64)); + match frames { + Some(frames) => frames.push(frame_id), + None => { + frame_cache.insert(page_id as u64, vec![frame_id]); + } + } + dbg!(&frame_cache); + if (frame_id + 1) as usize >= self.checkpoint_threshold { + self.checkpoint(pager); + } + Ok(()) + } + + /// Begin a write transaction + fn begin_write_tx(&self) -> Result<()> { + Ok(()) + } + + /// End a write transaction + fn end_write_tx(&self) -> Result<()> { + Ok(()) + } + + fn checkpoint(&self, pager: &Pager) -> Result { + for (page_id, frames) in self.frame_cache.borrow().iter() { + // move page from WAL to database file + // TODO(Pere): use splice syscall in linux to do zero-copy file page movements to improve perf + let page = pager.read_page(*page_id as usize)?; + if page.borrow().is_locked() { + return Ok(CheckpointStatus::IO); + } + } + Ok(()) } } @@ -60,19 +172,42 @@ impl WalFile { wal_path, file: RefCell::new(None), wal_header: RefCell::new(None), + frame_cache: RefCell::new(HashMap::new()), + min_frame: RefCell::new(0), + max_frame: RefCell::new(0), + nbackfills: RefCell::new(0), + checkpoint_threshold: 1000, } } fn ensure_init(&self) -> Result<()> { + println!("ensure"); if self.file.borrow().is_none() { - if let Ok(file) = self.io.open_file(&self.wal_path) { - *self.file.borrow_mut() = Some(file.clone()); - let wal_header = sqlite3_ondisk::begin_read_wal_header(file)?; - // TODO: Return a completion instead. - self.io.run_once()?; - self.wal_header.replace(Some(wal_header)); - } + println!("inside ensure"); + match self.io.open_file(&self.wal_path) { + Ok(file) => { + *self.file.borrow_mut() = Some(file.clone()); + let wal_header = match sqlite3_ondisk::begin_read_wal_header(file) { + Ok(header) => header, + Err(err) => panic!("{:?}", err), + }; + // TODO: Return a completion instead. + self.io.run_once()?; + self.wal_header.replace(Some(wal_header)); + dbg!(&self.wal_header); + } + Err(err) => panic!("{:?}", err), + }; } Ok(()) } + + fn frame_offset(&self, frame_id: u64) -> usize { + let header = self.wal_header.borrow(); + let header = header.as_ref().unwrap().borrow(); + let page_size = header.page_size; + let page_offset = frame_id * (page_size as u64 + WAL_FRAME_HEADER_SIZE as u64); + let offset = WAL_HEADER_SIZE as u64 + WAL_FRAME_HEADER_SIZE as u64 + page_offset; + offset as usize + } } diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 9c18e8cde..17c4b90fe 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -1,6 +1,7 @@ use std::cell::RefCell; use std::collections::HashMap; -use std::rc::Rc; +use std::rc::{Rc, Weak}; +use std::usize; use sqlite3_parser::ast; @@ -11,7 +12,7 @@ use crate::translate::plan::Search; use crate::types::{OwnedRecord, OwnedValue}; use crate::vdbe::builder::ProgramBuilder; use crate::vdbe::{BranchOffset, Insn, Program}; -use crate::Result; +use crate::{Connection, Result}; use super::expr::{ translate_aggregation, translate_condition_expr, translate_expr, translate_table_columns, @@ -1683,7 +1684,7 @@ fn epilogue( }); program.resolve_label(init_label, program.offset()); - program.emit_insn(Insn::Transaction); + program.emit_insn(Insn::Transaction { write: false }); program.emit_constant_insns(); program.emit_insn(Insn::Goto { @@ -1699,6 +1700,7 @@ pub fn emit_program( database_header: Rc>, mut plan: Plan, cache: ExpressionResultCache, + connection: Weak, ) -> Result { let (mut program, mut metadata, init_label, start_offset) = prologue(cache)?; loop { @@ -1717,7 +1719,7 @@ pub fn emit_program( } OpStepResult::Done => { epilogue(&mut program, &mut metadata, init_label, start_offset)?; - return Ok(program.build(database_header)); + return Ok(program.build(database_header, connection)); } } } diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 7f764ae87..ea890e994 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -1,3 +1,4 @@ +use std::rc::Weak; use std::{cell::RefCell, ops::Deref, rc::Rc}; use sqlite3_parser::ast::{ @@ -5,13 +6,13 @@ use sqlite3_parser::ast::{ }; use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; -use crate::Result; use crate::{ schema::{Schema, Table}, storage::sqlite3_ondisk::DatabaseHeader, translate::expr::translate_expr, vdbe::{builder::ProgramBuilder, Insn, Program}, }; +use crate::{Connection, Result}; #[allow(clippy::too_many_arguments)] pub fn translate_insert( @@ -23,6 +24,7 @@ pub fn translate_insert( body: &InsertBody, _returning: &Option>, database_header: Rc>, + connection: Weak, ) -> Result { assert!(with.is_none()); assert!(or_conflict.is_none()); @@ -203,11 +205,11 @@ pub fn translate_insert( description: String::new(), }); program.resolve_label(init_label, program.offset()); - program.emit_insn(Insn::Transaction); + program.emit_insn(Insn::Transaction { write: true }); program.emit_constant_insns(); program.emit_insn(Insn::Goto { target_pc: start_offset, }); program.resolve_deferred_labels(); - Ok(program.build(database_header)) + Ok(program.build(database_header, connection)) } diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 86219c833..dc77a00b3 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -16,13 +16,13 @@ pub(crate) mod planner; pub(crate) mod select; use std::cell::RefCell; -use std::rc::Rc; +use std::rc::{Rc, Weak}; use crate::schema::Schema; use crate::storage::pager::Pager; use crate::storage::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; use crate::vdbe::{builder::ProgramBuilder, Insn, Program}; -use crate::{bail_parse_error, Result}; +use crate::{bail_parse_error, Connection, Result}; use insert::translate_insert; use select::translate_select; use sqlite3_parser::ast; @@ -33,6 +33,7 @@ pub fn translate( stmt: ast::Stmt, database_header: Rc>, pager: Rc, + connection: Weak, ) -> Result { match stmt { ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"), @@ -53,12 +54,14 @@ pub fn translate( ast::Stmt::DropTable { .. } => bail_parse_error!("DROP TABLE not supported yet"), ast::Stmt::DropTrigger { .. } => bail_parse_error!("DROP TRIGGER not supported yet"), ast::Stmt::DropView { .. } => bail_parse_error!("DROP VIEW not supported yet"), - ast::Stmt::Pragma(name, body) => translate_pragma(&name, body, database_header, pager), + ast::Stmt::Pragma(name, body) => { + translate_pragma(&name, body, database_header, pager, connection) + } ast::Stmt::Reindex { .. } => bail_parse_error!("REINDEX not supported yet"), ast::Stmt::Release(_) => bail_parse_error!("RELEASE not supported yet"), ast::Stmt::Rollback { .. } => bail_parse_error!("ROLLBACK not supported yet"), ast::Stmt::Savepoint(_) => bail_parse_error!("SAVEPOINT not supported yet"), - ast::Stmt::Select(select) => translate_select(schema, select, database_header), + ast::Stmt::Select(select) => translate_select(schema, select, database_header, connection), ast::Stmt::Update { .. } => bail_parse_error!("UPDATE not supported yet"), ast::Stmt::Vacuum(_, _) => bail_parse_error!("VACUUM not supported yet"), ast::Stmt::Insert { @@ -77,6 +80,7 @@ pub fn translate( &body, &returning, database_header, + connection, ), } } @@ -86,6 +90,7 @@ fn translate_pragma( body: Option, database_header: Rc>, pager: Rc, + connection: Weak, ) -> Result { let mut program = ProgramBuilder::new(); let init_label = program.allocate_label(); @@ -96,6 +101,7 @@ fn translate_pragma( init_label, ); let start_offset = program.offset(); + let mut write = false; match body { None => { let pragma_result = program.alloc_register(); @@ -124,6 +130,7 @@ fn translate_pragma( }, _ => 0, }; + write = true; update_pragma( &name.name.0, value_to_update, @@ -140,13 +147,13 @@ fn translate_pragma( description: String::new(), }); program.resolve_label(init_label, program.offset()); - program.emit_insn(Insn::Transaction); + program.emit_insn(Insn::Transaction { write }); program.emit_constant_insns(); program.emit_insn(Insn::Goto { target_pc: start_offset, }); program.resolve_deferred_labels(); - Ok(program.build(database_header)) + Ok(program.build(database_header, connection)) } fn update_pragma(name: &str, value: i64, header: Rc>, pager: Rc) { diff --git a/core/translate/select.rs b/core/translate/select.rs index 07ea7d8f2..d486f6c23 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -1,6 +1,8 @@ +use std::rc::Weak; use std::{cell::RefCell, rc::Rc}; use crate::storage::sqlite3_ondisk::DatabaseHeader; +use crate::Connection; use crate::{schema::Schema, vdbe::Program, Result}; use sqlite3_parser::ast; @@ -12,8 +14,14 @@ pub fn translate_select( schema: &Schema, select: ast::Select, database_header: Rc>, + connection: Weak, ) -> Result { let select_plan = prepare_select_plan(schema, select)?; let (optimized_plan, expr_result_cache) = optimize_plan(select_plan)?; - emit_program(database_header, optimized_plan, expr_result_cache) + emit_program( + database_header, + optimized_plan, + expr_result_cache, + connection, + ) } diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 313967c5d..ee2bdb613 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -1,6 +1,10 @@ -use std::{cell::RefCell, collections::HashMap, rc::Rc}; +use std::{ + cell::RefCell, + collections::HashMap, + rc::{Rc, Weak}, +}; -use crate::storage::sqlite3_ondisk::DatabaseHeader; +use crate::{storage::sqlite3_ondisk::DatabaseHeader, Connection}; use super::{BranchOffset, CursorID, Insn, InsnReference, Program, Table}; @@ -354,7 +358,11 @@ impl ProgramBuilder { self.deferred_label_resolutions.clear(); } - pub fn build(self, database_header: Rc>) -> Program { + pub fn build( + self, + database_header: Rc>, + connection: Weak, + ) -> Program { assert!( self.deferred_label_resolutions.is_empty(), "deferred_label_resolutions is not empty when build() is called, did you forget to call resolve_deferred_labels()?" @@ -369,6 +377,8 @@ impl ProgramBuilder { cursor_ref: self.cursor_ref, database_header, comments: self.comments, + connection, + auto_commit: true, } } } diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index de918398a..b437b9d5d 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -395,10 +395,10 @@ pub fn insn_to_str( 0, "".to_string(), ), - Insn::Transaction => ( + Insn::Transaction { write } => ( "Transaction", 0, - 0, + *write as i32, 0, OwnedValue::Text(Rc::new("".to_string())), 0, diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index f3ee80fa4..71d4b6f54 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -33,7 +33,8 @@ use crate::storage::{btree::BTreeCursor, pager::Pager}; use crate::types::{ AggContext, Cursor, CursorResult, OwnedRecord, OwnedValue, Record, SeekKey, SeekOp, }; -use crate::{Result, DATABASE_VERSION}; +use crate::DATABASE_VERSION; +use crate::{Connection, Result, TransactionState}; use datetime::{exec_date, exec_time, exec_unixepoch}; @@ -44,7 +45,7 @@ use std::borrow::BorrowMut; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; use std::fmt::Display; -use std::rc::Rc; +use std::rc::{Rc, Weak}; pub type BranchOffset = i64; @@ -240,7 +241,9 @@ pub enum Insn { }, // Start a transaction. - Transaction, + Transaction { + write: bool, + }, // Branch to the given PC. Goto { @@ -529,6 +532,8 @@ pub struct Program { pub cursor_ref: Vec<(Option, Option)>, pub database_header: Rc>, pub comments: HashMap, + pub connection: Weak, + pub auto_commit: bool, } impl Program { @@ -555,6 +560,7 @@ impl Program { state: &'a mut ProgramState, pager: Rc, ) -> Result> { + dbg!(&self.connection.upgrade().is_none()); loop { let insn = &self.insns[state.pc as usize]; trace_insn(self, state.pc as InsnReference, insn); @@ -1093,11 +1099,36 @@ impl Program { ))); } } - pager.end_read_tx()?; + if self.auto_commit { + pager.end_tx()?; + } return Ok(StepResult::Done); } - Insn::Transaction => { - pager.begin_read_tx()?; + Insn::Transaction { write } => { + let connection = self.connection.upgrade().unwrap(); + if let Some(db) = connection.db.upgrade() { + // TODO(pere): are backpointers good ?? this looks ugly af + // upgrade transaction if needed + let new_transaction_state = + match (db.transaction_state.borrow().clone(), write) { + (crate::TransactionState::Write, true) => TransactionState::Write, + (crate::TransactionState::Write, false) => TransactionState::Write, + (crate::TransactionState::Read, true) => TransactionState::Write, + (crate::TransactionState::Read, false) => TransactionState::Read, + (crate::TransactionState::None, true) => TransactionState::Read, + (crate::TransactionState::None, false) => TransactionState::Read, + }; + // TODO(Pere): + // 1. lock wal + // 2. lock shared + // 3. lock write db if write + db.transaction_state.replace(new_transaction_state.clone()); + if matches!(new_transaction_state, TransactionState::Write) { + pager.begin_read_tx()?; + } else { + pager.begin_write_tx()?; + } + } state.pc += 1; } Insn::Goto { target_pc } => { diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 9990a3320..67c975fb6 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -5,6 +5,7 @@ use log::trace; use std::cell::RefCell; use std::ffi; +use std::rc::Rc; use std::sync::Arc; macro_rules! stub { @@ -32,8 +33,8 @@ pub mod util; use util::sqlite3_safety_check_sick_or_ok; pub struct sqlite3 { - pub(crate) _db: limbo_core::Database, - pub(crate) conn: limbo_core::Connection, + pub(crate) _db: Rc, + pub(crate) conn: Rc, pub(crate) err_code: ffi::c_int, pub(crate) err_mask: ffi::c_int, pub(crate) malloc_failed: bool, @@ -42,7 +43,7 @@ pub struct sqlite3 { } impl sqlite3 { - pub fn new(db: limbo_core::Database, conn: limbo_core::Connection) -> Self { + pub fn new(db: Rc, conn: Rc) -> Self { Self { _db: db, conn, diff --git a/test/src/lib.rs b/test/src/lib.rs index 46ea9f704..d77c528d3 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -1,5 +1,6 @@ use limbo_core::Database; use std::path::PathBuf; +use std::rc::Rc; use std::sync::Arc; use tempfile::TempDir; @@ -23,7 +24,7 @@ impl TempDatabase { Self { path, io } } - pub fn connect_limbo(&self) -> limbo_core::Connection { + pub fn connect_limbo(&self) -> Rc { let db = Database::open_file(self.io.clone(), self.path.to_str().unwrap()).unwrap(); db.connect()