diff --git a/core/benches/mvcc_benchmark.rs b/core/benches/mvcc_benchmark.rs index 547a473b4..ecce8c06d 100644 --- a/core/benches/mvcc_benchmark.rs +++ b/core/benches/mvcc_benchmark.rs @@ -47,9 +47,7 @@ fn bench(c: &mut Criterion) { let conn = &db.conn; let tx_id = db.mvcc_store.begin_tx(conn.get_pager().clone()).unwrap(); let mv_store = &db.mvcc_store; - let mut sm = mv_store - .commit_tx(tx_id, conn.get_pager().clone(), conn) - .unwrap(); + let mut sm = mv_store.commit_tx(tx_id, conn).unwrap(); // TODO: sync IO hack loop { let res = sm.step(mv_store).unwrap(); @@ -76,9 +74,7 @@ fn bench(c: &mut Criterion) { ) .unwrap(); let mv_store = &db.mvcc_store; - let mut sm = mv_store - .commit_tx(tx_id, conn.get_pager().clone(), conn) - .unwrap(); + let mut sm = mv_store.commit_tx(tx_id, conn).unwrap(); // TODO: sync IO hack loop { let res = sm.step(mv_store).unwrap(); @@ -111,9 +107,7 @@ fn bench(c: &mut Criterion) { ) .unwrap(); let mv_store = &db.mvcc_store; - let mut sm = mv_store - .commit_tx(tx_id, conn.get_pager().clone(), conn) - .unwrap(); + let mut sm = mv_store.commit_tx(tx_id, conn).unwrap(); // TODO: sync IO hack loop { let res = sm.step(mv_store).unwrap(); diff --git a/core/mvcc/database/checkpoint_state_machine.rs b/core/mvcc/database/checkpoint_state_machine.rs new file mode 100644 index 000000000..5e4a9c4ff --- /dev/null +++ b/core/mvcc/database/checkpoint_state_machine.rs @@ -0,0 +1,540 @@ +use crate::mvcc::clock::LogicalClock; +use crate::mvcc::database::{ + DeleteRowStateMachine, MvStore, RowVersion, TxTimestampOrID, WriteRowStateMachine, +}; +use crate::state_machine::{StateMachine, StateTransition, TransitionResult}; +use crate::storage::btree::BTreeCursor; +use crate::storage::pager::CreateBTreeFlags; +use crate::storage::wal::{CheckpointMode, TursoRwLock}; +use crate::types::{IOResult, ImmutableRecord, RecordCursor}; +use crate::{CheckpointResult, Connection, IOExt, Pager, RefValue, Result, TransactionState}; +use parking_lot::RwLock; +use 
std::collections::HashMap; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +#[derive(Debug)] +pub enum CheckpointState { + AcquireLock, + BeginPagerTxn, + WriteRow { + write_set_index: usize, + requires_seek: bool, + }, + WriteRowStateMachine { + write_set_index: usize, + }, + DeleteRowStateMachine { + write_set_index: usize, + }, + CommitPagerTxn, + TruncateLogicalLog, + FsyncLogicalLog, + CheckpointWal, + Finalize, +} + +/// The states of the locks held by the state machine - these are tracked for error handling so that they are +/// released if the state machine fails. +pub struct LockStates { + blocking_checkpoint_lock_held: bool, + pager_read_tx: bool, + pager_write_tx: bool, +} + +/// A state machine that performs a complete checkpoint operation on the MVCC store. +/// +/// The checkpoint process: +/// 1. Takes a blocking lock on the database so that no other transactions can run during the checkpoint. +/// 2. Determines which row versions should be written to the B-tree. +/// 3. Begins a pager transaction +/// 4. Writes all the selected row versions to the B-tree. +/// 5. Commits the pager transaction, effectively flushing to the WAL +/// 6. Truncates the logical log file +/// 7. Immediately does a TRUNCATE checkpoint from the WAL to the DB +/// 8. Releases the blocking_checkpoint_lock +pub struct CheckpointStateMachine { + /// The current state of the state machine + state: CheckpointState, + /// The states of the locks held by the state machine - these are tracked for error handling so that they are + /// released if the state machine fails. + lock_states: LockStates, + /// The highest transaction ID that has been checkpointed in a previous checkpoint. + checkpointed_txid_max_old: u64, + /// The highest transaction ID that will be checkpointed in the current checkpoint. + checkpointed_txid_max_new: u64, + /// Pager used for writing to the B-tree + pager: Arc, + /// MVCC store containing the row versions. 
+ mvstore: Arc>, + /// Connection to the database + connection: Arc, + /// Lock used to block other transactions from running during the checkpoint + checkpoint_lock: Arc, + /// All committed versions to write to the B-tree. + /// In the case of CREATE TABLE / DROP TABLE ops, contains a [SpecialWrite] to create/destroy the B-tree. + write_set: Vec<(RowVersion, Option)>, + /// State machine for writing rows to the B-tree + write_row_state_machine: Option>, + /// State machine for deleting rows from the B-tree + delete_row_state_machine: Option>, + /// Cursors for the B-trees + cursors: HashMap>>, + /// Result of the checkpoint + checkpoint_result: Option, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +/// Special writes for CREATE TABLE / DROP TABLE ops. +/// These are used to create/destroy B-trees during pager ops. +pub enum SpecialWrite { + BTreeCreate { root_page: u64 }, + BTreeDestroy { root_page: u64, num_columns: usize }, +} + +impl CheckpointStateMachine { + pub fn new( + pager: Arc, + mvstore: Arc>, + connection: Arc, + ) -> Self { + let checkpoint_lock = mvstore.blocking_checkpoint_lock.clone(); + Self { + state: CheckpointState::AcquireLock, + lock_states: LockStates { + blocking_checkpoint_lock_held: false, + pager_read_tx: false, + pager_write_tx: false, + }, + pager, + checkpointed_txid_max_old: mvstore.checkpointed_txid_max.load(Ordering::SeqCst), + checkpointed_txid_max_new: mvstore.checkpointed_txid_max.load(Ordering::SeqCst), + mvstore, + connection, + checkpoint_lock, + write_set: Vec::new(), + write_row_state_machine: None, + delete_row_state_machine: None, + cursors: HashMap::new(), + checkpoint_result: None, + } + } + + /// Collect all committed versions that need to be written to the B-tree. + /// We must only write to the B-tree if: + /// 1. The row has not already been checkpointed in a previous checkpoint. + /// TODO: garbage collect row versions after checkpointing. + /// 2. 
Either: + /// * The row is not a delete (we inserted or changed an existing row), OR + /// * The row is a delete AND it exists in the database file already. + /// If the row didn't exist in the database file and was deleted, we can simply not write it. + fn collect_committed_versions(&mut self) { + // Keep track of the highest timestamp that will be checkpointed in the current checkpoint; + // This value will be used at the end of the checkpoint to update the corresponding value in + // the MVCC store, so that we don't checkpoint the same row versions again on the next checkpoint. + let mut max_timestamp = self.checkpointed_txid_max_old; + + for entry in self.mvstore.rows.iter() { + let row_versions = entry.value().read(); + let mut exists_in_db_file = false; + for (i, version) in row_versions.iter().enumerate() { + let is_last = i == row_versions.len() - 1; + if let TxTimestampOrID::Timestamp(ts) = &version.begin { + if *ts <= self.checkpointed_txid_max_old { + exists_in_db_file = true; + } + + let current_version_ts = + if let Some(TxTimestampOrID::Timestamp(ts_end)) = version.end { + ts_end.max(*ts) + } else { + *ts + }; + if current_version_ts <= self.checkpointed_txid_max_old { + // already checkpointed. TODO: garbage collect row versions after checkpointing. 
+ continue; + } + + let get_root_page = |row_data: &Vec| { + let row_data = ImmutableRecord::from_bin_record(row_data.clone()); + let mut record_cursor = RecordCursor::new(); + record_cursor.parse_full_header(&row_data).unwrap(); + let RefValue::Integer(root_page) = + record_cursor.get_value(&row_data, 3).unwrap() + else { + panic!( + "Expected integer value for root page, got {:?}", + record_cursor.get_value(&row_data, 3) + ); + }; + root_page as u64 + }; + + max_timestamp = max_timestamp.max(current_version_ts); + if is_last { + let is_delete = version.end.is_some(); + let should_be_deleted_from_db_file = is_delete && exists_in_db_file; + + // We might need to create or destroy a B-tree in the pager during checkpoint if a row in root page 1 is deleted or created. + let special_write = + if should_be_deleted_from_db_file && version.row.id.table_id == 1 { + let root_page = get_root_page(&version.row.data); + Some(SpecialWrite::BTreeDestroy { + root_page, + num_columns: version.row.column_count, + }) + } else if !exists_in_db_file && version.row.id.table_id == 1 { + let root_page = get_root_page(&version.row.data); + Some(SpecialWrite::BTreeCreate { root_page }) + } else { + None + }; + + // Only write the row to the B-tree if it is not a delete, or if it is a delete and it exists in the database file. 
+ if !is_delete || should_be_deleted_from_db_file { + self.write_set.push((version.clone(), special_write)); + } + } + } + } + } + self.checkpointed_txid_max_new = max_timestamp; + } + + /// Get the current row version to write to the B-tree + fn get_current_row_version( + &self, + write_set_index: usize, + ) -> Option<&(RowVersion, Option)> { + self.write_set.get(write_set_index) + } + + /// Check if we have more rows to write + fn has_more_rows(&self, write_set_index: usize) -> bool { + write_set_index < self.write_set.len() + } + + /// Fsync the logical log file + fn fsync_logical_log(&self) -> Result> { + self.mvstore.storage.sync() + } + + /// Truncate the logical log file + fn truncate_logical_log(&self) -> Result> { + self.mvstore.storage.truncate() + } + + /// Perform a TRUNCATE checkpoint on the WAL + fn checkpoint_wal(&self) -> Result> { + let Some(wal) = &self.pager.wal else { + panic!("No WAL to checkpoint"); + }; + let mut wal_ref = wal.borrow_mut(); + match wal_ref.checkpoint( + &self.pager, + CheckpointMode::Truncate { + upper_bound_inclusive: None, + }, + )? 
{ + IOResult::Done(result) => Ok(IOResult::Done(result)), + IOResult::IO(io) => Ok(IOResult::IO(io)), + } + } + + fn step_inner(&mut self, _context: &()) -> Result> { + match &self.state { + CheckpointState::AcquireLock => { + tracing::debug!("Acquiring blocking checkpoint lock"); + let locked = self.checkpoint_lock.write(); + if !locked { + return Err(crate::LimboError::Busy); + } + self.lock_states.blocking_checkpoint_lock_held = true; + + self.collect_committed_versions(); + tracing::debug!("Collected {} committed versions", self.write_set.len()); + + if self.write_set.is_empty() { + // Nothing to checkpoint, skip to truncate logical log + self.state = CheckpointState::TruncateLogicalLog; + } else { + self.state = CheckpointState::BeginPagerTxn; + } + Ok(TransitionResult::Continue) + } + CheckpointState::BeginPagerTxn => { + tracing::debug!("Beginning pager transaction"); + // Start a pager transaction to write committed versions to B-tree + let result = self.pager.begin_read_tx(); + if let Err(crate::LimboError::Busy) = result { + return Err(crate::LimboError::Busy); + } + result?; + self.lock_states.pager_read_tx = true; + + let result = self.pager.io.block(|| self.pager.begin_write_tx()); + if let Err(crate::LimboError::Busy) = result { + return Err(crate::LimboError::Busy); + } + result?; + *self.connection.transaction_state.write() = TransactionState::Write { + schema_did_change: false, + }; // TODO: schema_did_change?? 
+ self.lock_states.pager_write_tx = true; + self.state = CheckpointState::WriteRow { + write_set_index: 0, + requires_seek: true, + }; + Ok(TransitionResult::Continue) + } + + CheckpointState::WriteRow { + write_set_index, + requires_seek, + } => { + let write_set_index = *write_set_index; + let requires_seek = *requires_seek; + + if !self.has_more_rows(write_set_index) { + // Done writing all rows + self.state = CheckpointState::CommitPagerTxn; + return Ok(TransitionResult::Continue); + } + + let (num_columns, table_id, special_write) = { + let (row_version, special_write) = + self.get_current_row_version(write_set_index).unwrap(); + ( + row_version.row.column_count, + row_version.row.id.table_id, + *special_write, + ) + }; + + // Handle CREATE TABLE / DROP TABLE ops + if let Some(special_write) = special_write { + match special_write { + SpecialWrite::BTreeCreate { root_page } => { + let created_root_page = self.pager.io.block(|| { + self.pager.btree_create(&CreateBTreeFlags::new_table()) + })?; + assert_eq!(created_root_page as u64, root_page, "Created root page does not match expected root page: {created_root_page} != {root_page}"); + } + SpecialWrite::BTreeDestroy { + root_page, + num_columns, + } => { + let cursor = if let Some(cursor) = self.cursors.get(&root_page) { + cursor.clone() + } else { + let cursor = BTreeCursor::new_table( + None, + self.pager.clone(), + root_page as usize, + num_columns, + ); + let cursor = Arc::new(RwLock::new(cursor)); + self.cursors.insert(root_page, cursor.clone()); + cursor + }; + self.pager.io.block(|| cursor.write().btree_destroy())?; + self.cursors.remove(&root_page); + } + } + } + + // Get or create cursor for this table + let cursor = if let Some(cursor) = self.cursors.get(&table_id) { + cursor.clone() + } else { + let cursor = BTreeCursor::new_table( + None, // Write directly to B-tree + self.pager.clone(), + table_id as usize, + num_columns, + ); + let cursor = Arc::new(RwLock::new(cursor)); + 
self.cursors.insert(table_id, cursor.clone()); + cursor + }; + + let (row_version, _) = self.get_current_row_version(write_set_index).unwrap(); + + // Check if this is an insert or delete + if row_version.end.is_some() { + // This is a delete operation + let state_machine = self + .mvstore + .delete_row_from_pager(row_version.row.id, cursor)?; + self.delete_row_state_machine = Some(state_machine); + self.state = CheckpointState::DeleteRowStateMachine { write_set_index }; + } else { + // This is an insert/update operation + let state_machine = + self.mvstore + .write_row_to_pager(&row_version.row, cursor, requires_seek)?; + self.write_row_state_machine = Some(state_machine); + self.state = CheckpointState::WriteRowStateMachine { write_set_index }; + } + + Ok(TransitionResult::Continue) + } + + CheckpointState::WriteRowStateMachine { write_set_index } => { + let write_set_index = *write_set_index; + let write_row_state_machine = self.write_row_state_machine.as_mut().unwrap(); + + match write_row_state_machine.step(&())? { + IOResult::IO(io) => Ok(TransitionResult::Io(io)), + IOResult::Done(_) => { + self.state = CheckpointState::WriteRow { + write_set_index: write_set_index + 1, + requires_seek: true, + }; + Ok(TransitionResult::Continue) + } + } + } + + CheckpointState::DeleteRowStateMachine { write_set_index } => { + let write_set_index = *write_set_index; + let delete_row_state_machine = self.delete_row_state_machine.as_mut().unwrap(); + + match delete_row_state_machine.step(&())? 
{ + IOResult::IO(io) => Ok(TransitionResult::Io(io)), + IOResult::Done(_) => { + self.state = CheckpointState::WriteRow { + write_set_index: write_set_index + 1, + requires_seek: true, + }; + Ok(TransitionResult::Continue) + } + } + } + + CheckpointState::CommitPagerTxn => { + tracing::debug!("Committing pager transaction"); + let result = self.pager.end_tx(false, &self.connection)?; + match result { + IOResult::Done(_) => { + self.state = CheckpointState::TruncateLogicalLog; + self.lock_states.pager_read_tx = false; + self.lock_states.pager_write_tx = false; + *self.connection.transaction_state.write() = TransactionState::None; + let header = self + .pager + .io + .block(|| { + self.pager.with_header_mut(|header| { + header.schema_cookie = self + .connection + .db + .schema + .lock() + .unwrap() + .schema_version + .into(); + *header + }) + }) + .unwrap(); + self.mvstore.global_header.write().replace(header); + Ok(TransitionResult::Continue) + } + IOResult::IO(io) => Ok(TransitionResult::Io(io)), + } + } + + CheckpointState::TruncateLogicalLog => { + tracing::debug!("Truncating logical log file"); + match self.truncate_logical_log()? { + IOResult::Done(_) => { + self.state = CheckpointState::FsyncLogicalLog; + Ok(TransitionResult::Continue) + } + IOResult::IO(io) => { + if io.finished() { + self.state = CheckpointState::CheckpointWal; + Ok(TransitionResult::Continue) + } else { + Ok(TransitionResult::Io(io)) + } + } + } + } + + CheckpointState::FsyncLogicalLog => { + tracing::debug!("Fsyncing logical log file"); + match self.fsync_logical_log()? { + IOResult::Done(_) => { + self.state = CheckpointState::CheckpointWal; + Ok(TransitionResult::Continue) + } + IOResult::IO(io) => Ok(TransitionResult::Io(io)), + } + } + + CheckpointState::CheckpointWal => { + tracing::debug!("Performing TRUNCATE checkpoint on WAL"); + match self.checkpoint_wal()? 
{ + IOResult::Done(result) => { + self.checkpoint_result = Some(result); + self.state = CheckpointState::Finalize; + Ok(TransitionResult::Continue) + } + IOResult::IO(io) => Ok(TransitionResult::Io(io)), + } + } + + CheckpointState::Finalize => { + tracing::debug!("Releasing blocking checkpoint lock"); + self.mvstore + .checkpointed_txid_max + .store(self.checkpointed_txid_max_new, Ordering::SeqCst); + self.checkpoint_lock.unlock(); + self.finalize(&())?; + Ok(TransitionResult::Done( + self.checkpoint_result.take().unwrap(), + )) + } + } + } +} + +impl StateTransition for CheckpointStateMachine { + type Context = (); + type SMResult = CheckpointResult; + + fn step(&mut self, _context: &Self::Context) -> Result> { + let res = self.step_inner(&()); + match res { + Err(err) => { + tracing::info!("Error in checkpoint state machine: {err}"); + if self.lock_states.pager_write_tx { + let rollback = true; + self.pager + .io + .block(|| self.pager.end_tx(rollback, self.connection.as_ref())) + .expect("failed to end pager write tx"); + *self.connection.transaction_state.write() = TransactionState::None; + } else if self.lock_states.pager_read_tx { + self.pager.end_read_tx().unwrap(); + *self.connection.transaction_state.write() = TransactionState::None; + } + if self.lock_states.blocking_checkpoint_lock_held { + self.checkpoint_lock.unlock(); + } + Err(err) + } + Ok(result) => Ok(result), + } + } + + fn finalize(&mut self, _context: &Self::Context) -> Result<()> { + Ok(()) + } + + fn is_finalized(&self) -> bool { + matches!(self.state, CheckpointState::Finalize) + } +} diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index 95b36beb9..f78c46d1a 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -19,7 +19,6 @@ use crate::Result; use crate::{Connection, Pager}; use crossbeam_skiplist::{SkipMap, SkipSet}; use parking_lot::RwLock; -use std::collections::HashMap; use std::collections::HashSet; use std::fmt::Debug; use 
std::marker::PhantomData; @@ -29,6 +28,9 @@ use std::sync::Arc; use tracing::instrument; use tracing::Level; +pub mod checkpoint_state_machine; +pub use checkpoint_state_machine::{CheckpointState, CheckpointStateMachine}; + #[cfg(test)] pub mod tests; @@ -262,41 +264,11 @@ impl AtomicTransactionState { #[derive(Debug)] pub enum CommitState { Initial, - BeginPagerTxn { - end_ts: u64, - }, - WriteRow { - end_ts: u64, - write_set_index: usize, - requires_seek: bool, - }, - WriteRowStateMachine { - end_ts: u64, - write_set_index: usize, - }, - DeleteRowStateMachine { - end_ts: u64, - write_set_index: usize, - }, - CommitPagerTxn { - end_ts: u64, - }, - Commit { - end_ts: u64, - }, - BeginCommitLogicalLog { - end_ts: u64, - log_record: LogRecord, - }, - EndCommitLogicalLog { - end_ts: u64, - }, - SyncLogicalLog { - end_ts: u64, - }, - CommitEnd { - end_ts: u64, - }, + Commit { end_ts: u64 }, + BeginCommitLogicalLog { end_ts: u64, log_record: LogRecord }, + EndCommitLogicalLog { end_ts: u64 }, + SyncLogicalLog { end_ts: u64 }, + CommitEnd { end_ts: u64 }, } #[derive(Debug)] @@ -311,21 +283,16 @@ pub enum WriteRowState { #[derive(Debug)] struct CommitCoordinator { pager_commit_lock: Arc, - commits_waiting: Arc, } pub struct CommitStateMachine { state: CommitState, is_finalized: bool, - pager: Arc, tx_id: TxID, connection: Arc, /// Write set sorted by table id and row id write_set: Vec, - write_row_state_machine: Option>, - delete_row_state_machine: Option>, commit_coordinator: Arc, - cursors: HashMap>>, header: Arc>>, _phantom: PhantomData, } @@ -365,7 +332,6 @@ pub struct DeleteRowStateMachine { impl CommitStateMachine { fn new( state: CommitState, - pager: Arc, tx_id: TxID, connection: Arc, commit_coordinator: Arc, @@ -374,46 +340,14 @@ impl CommitStateMachine { Self { state, is_finalized: false, - pager, tx_id, connection, write_set: Vec::new(), - write_row_state_machine: None, - delete_row_state_machine: None, commit_coordinator, - cursors: HashMap::new(), header, 
_phantom: PhantomData, } } - - /// We need to update pager's header to account for changes made by other transactions. - fn update_pager_header(&self, mvcc_store: &MvStore) -> Result<()> { - let header = self.header.read(); - let last_commited_header = header.as_ref().expect("Header not found"); - self.pager.io.block(|| self.pager.maybe_allocate_page1())?; - let _ = self.pager.io.block(|| { - self.pager.with_header_mut(|header_in_pager| { - let header_in_transaction = mvcc_store.get_transaction_database_header(&self.tx_id); - tracing::debug!("update header here {}", header_in_transaction.schema_cookie); - // database_size should only be updated in each commit so it should be safe to assume correct database_size is in last_commited_header - header_in_pager.database_size = last_commited_header.database_size; - if header_in_transaction.schema_cookie < last_commited_header.schema_cookie { - tracing::error!("txn's schema cookie went back in time, aborting"); - return Err(LimboError::SchemaUpdated); - } - - assert!( - header_in_transaction.schema_cookie >= last_commited_header.schema_cookie, - "txn's schema cookie went back in time" - ); - header_in_pager.schema_cookie = header_in_transaction.schema_cookie; - // TODO: deal with other fields - Ok(()) - }) - })?; - Ok(()) - } } impl WriteRowStateMachine { @@ -541,245 +475,13 @@ impl StateTransition for CommitStateMachine { mvcc_store.release_exclusive_tx(&self.tx_id); self.commit_coordinator.pager_commit_lock.unlock(); } + mvcc_store.remove_tx(self.tx_id); self.finalize(mvcc_store)?; return Ok(TransitionResult::Done(())); } self.state = CommitState::Commit { end_ts }; Ok(TransitionResult::Continue) } - CommitState::BeginPagerTxn { end_ts } => { - // FIXME: how do we deal with multiple concurrent writes? - // WAL requires a txn to be written sequentially. Either we: - // 1. Wait for currently writer to finish before second txn starts. - // 2. 
Choose a txn to write depending on some heuristics like amount of frames will be written. - // 3. .. - - // If this is the exclusive transaction, we already acquired a write transaction - // on the pager in begin_exclusive_tx() and don't need to acquire it. - if mvcc_store.is_exclusive_tx(&self.tx_id) { - self.update_pager_header(mvcc_store)?; - self.state = CommitState::WriteRow { - end_ts: *end_ts, - write_set_index: 0, - requires_seek: true, - }; - return Ok(TransitionResult::Continue); - } else if mvcc_store.has_exclusive_tx() { - // There is an exclusive transaction holding the write lock. We must abort. - return Err(LimboError::WriteWriteConflict); - } - // Currently txns are queued without any heuristics whasoever. This is important because - // we need to ensure writes to disk happen sequentially. - // * We don't want txns to write to WAL in parallel. - // * We don't want BTree modifications to happen in parallel. - // If any of these were to happen, we would find ourselves in a bad corruption situation. - - // NOTE: since we are blocking for `begin_write_tx` we do not care about re-entrancy right now. - let locked = self.commit_coordinator.pager_commit_lock.write(); - if !locked { - self.commit_coordinator - .commits_waiting - .fetch_add(1, Ordering::SeqCst); - // FIXME: IOCompletions still needs a yield variant... - return Ok(TransitionResult::Io(crate::types::IOCompletions::Single( - Completion::new_dummy(), - ))); - } - - self.update_pager_header(mvcc_store)?; - - { - let mut wal = self.pager.wal.as_ref().unwrap().borrow_mut(); - // we need to update the max frame to the latest shared max frame in order to avoid snapshot staleness - wal.update_max_frame(); - } - - // We started a pager read transaction at the beginning of the MV transaction, because - // any reads we do from the database file and WAL must uphold snapshot isolation. - // However, now we must end and immediately restart the read transaction before committing. 
- // This is because other transactions may have committed writes to the DB file or WAL, - // and our pager must read in those changes when applying our writes; otherwise we would overwrite - // the changes from the previous committed transactions. - // - // Note that this would be incredibly unsafe in the regular transaction model, but in MVCC we trust - // the MV-store to uphold the guarantee that no write-write conflicts happened. - self.pager.end_read_tx().expect("end_read_tx cannot fail"); - let result = self.pager.begin_read_tx(); - if let Err(LimboError::Busy) = result { - // We cannot obtain a WAL read lock due to contention, so we must abort. - self.commit_coordinator.pager_commit_lock.unlock(); - return Err(LimboError::WriteWriteConflict); - } - result?; - let result = self.pager.io.block(|| self.pager.begin_write_tx()); - if let Err(LimboError::Busy) = result { - // There is a non-CONCURRENT transaction holding the write lock. We must abort. - self.commit_coordinator.pager_commit_lock.unlock(); - return Err(LimboError::WriteWriteConflict); - } - result?; - self.state = CommitState::WriteRow { - end_ts: *end_ts, - write_set_index: 0, - requires_seek: true, - }; - return Ok(TransitionResult::Continue); - } - CommitState::WriteRow { - end_ts, - write_set_index, - requires_seek, - } => { - if *write_set_index == self.write_set.len() { - self.state = CommitState::CommitPagerTxn { end_ts: *end_ts }; - return Ok(TransitionResult::Continue); - } - let id = &self.write_set[*write_set_index]; - if let Some(row_versions) = mvcc_store.rows.get(id) { - let row_versions = row_versions.value().read(); - // Find rows that were written by this transaction. - // Hekaton uses oldest-to-newest order for row versions, so we reverse iterate to find the newest one - // this transaction changed. 
- for row_version in row_versions.iter().rev() { - if let TxTimestampOrID::TxID(row_tx_id) = row_version.begin { - if row_tx_id == self.tx_id { - let cursor = if let Some(cursor) = self.cursors.get(&id.table_id) { - cursor.clone() - } else { - let cursor = BTreeCursor::new_table( - None, // Write directly to B-tree - self.pager.clone(), - id.table_id as usize, - row_version.row.column_count, - ); - let cursor = Arc::new(RwLock::new(cursor)); - self.cursors.insert(id.table_id, cursor.clone()); - cursor - }; - let state_machine = mvcc_store.write_row_to_pager( - &row_version.row, - cursor, - *requires_seek, - )?; - self.write_row_state_machine = Some(state_machine); - - self.state = CommitState::WriteRowStateMachine { - end_ts: *end_ts, - write_set_index: *write_set_index, - }; - break; - } - } - if let Some(TxTimestampOrID::TxID(row_tx_id)) = row_version.end { - if row_tx_id == self.tx_id { - let column_count = row_version.row.column_count; - let cursor = if let Some(cursor) = self.cursors.get(&id.table_id) { - cursor.clone() - } else { - let cursor = BTreeCursor::new_table( - None, // Write directly to B-tree - self.pager.clone(), - id.table_id as usize, - column_count, - ); - let cursor = Arc::new(RwLock::new(cursor)); - self.cursors.insert(id.table_id, cursor.clone()); - cursor - }; - let state_machine = - mvcc_store.delete_row_from_pager(row_version.row.id, cursor)?; - self.delete_row_state_machine = Some(state_machine); - self.state = CommitState::DeleteRowStateMachine { - end_ts: *end_ts, - write_set_index: *write_set_index, - }; - break; - } - } - } - } - Ok(TransitionResult::Continue) - } - - CommitState::WriteRowStateMachine { - end_ts, - write_set_index, - } => { - let write_row_state_machine = self.write_row_state_machine.as_mut().unwrap(); - match write_row_state_machine.step(&())? 
{ - IOResult::IO(io) => return Ok(TransitionResult::Io(io)), - IOResult::Done(_) => { - let requires_seek = { - if let Some(next_id) = self.write_set.get(*write_set_index + 1) { - let current_id = &self.write_set[*write_set_index]; - if current_id.table_id == next_id.table_id - && current_id.row_id + 1 == next_id.row_id - { - // simple optimizaiton for sequential inserts with inceasing by 1 ids - // we should probably just check record in next row and see if it requires seek - false - } else { - true - } - } else { - false - } - }; - self.state = CommitState::WriteRow { - end_ts: *end_ts, - write_set_index: *write_set_index + 1, - requires_seek, - }; - return Ok(TransitionResult::Continue); - } - } - } - CommitState::DeleteRowStateMachine { - end_ts, - write_set_index, - } => { - let delete_row_state_machine = self.delete_row_state_machine.as_mut().unwrap(); - match delete_row_state_machine.step(&())? { - IOResult::IO(io) => return Ok(TransitionResult::Io(io)), - IOResult::Done(_) => { - self.state = CommitState::WriteRow { - end_ts: *end_ts, - write_set_index: *write_set_index + 1, - requires_seek: true, - }; - return Ok(TransitionResult::Continue); - } - } - } - CommitState::CommitPagerTxn { end_ts } => { - // Write committed data to pager for persistence - // Flush dirty pages to WAL - this is critical for data persistence - // Similar to what step_end_write_txn does for legacy transactions - - let result = self - .pager - .end_tx( - false, // rollback = false since we're committing - &self.connection, - ) - .map_err(|e| LimboError::InternalError(e.to_string())) - .unwrap(); - match result { - IOResult::Done(_) => { - // FIXME: hack for now to keep database header updated for pager commit - let tx = mvcc_store.txs.get(&self.tx_id).unwrap(); - let tx_unlocked = tx.value(); - self.header.write().replace(*tx_unlocked.header.read()); - self.commit_coordinator.pager_commit_lock.unlock(); - // TODO: here mark we are ready for a batch - self.state = CommitState::Commit 
{ end_ts: *end_ts }; - return Ok(TransitionResult::Continue); - } - IOResult::IO(io) => { - return Ok(TransitionResult::Io(io)); - } - } - } CommitState::Commit { end_ts } => { let mut log_record = LogRecord::new(*end_ts); if !mvcc_store.is_exclusive_tx(&self.tx_id) && mvcc_store.has_exclusive_tx() { @@ -885,6 +587,10 @@ impl StateTransition for CommitStateMachine { tx_unlocked .state .store(TransactionState::Committed(*end_ts)); + mvcc_store + .global_header + .write() + .replace(*tx_unlocked.header.read()); // We have now updated all the versions with a reference to the // transaction ID to a timestamp and can, therefore, remove the // transaction. Please note that when we move to lockless, the @@ -1090,6 +796,7 @@ pub struct MvStore { txs: SkipMap, tx_ids: AtomicU64, next_rowid: AtomicU64, + next_table_id: AtomicU64, clock: Clock, storage: Storage, loaded_tables: RwLock>, @@ -1110,6 +817,9 @@ pub struct MvStore { /// - Immediately TRUNCATE checkpoint the WAL into the database file. /// - Release the blocking_checkpoint_lock. blocking_checkpoint_lock: Arc, + /// The highest transaction ID that has been checkpointed. + /// Used to skip checkpointing transactions that have already been checkpointed. + checkpointed_txid_max: AtomicU64, } impl MvStore { @@ -1120,19 +830,28 @@ impl MvStore { txs: SkipMap::new(), tx_ids: AtomicU64::new(1), // let's reserve transaction 0 for special purposes next_rowid: AtomicU64::new(0), // TODO: determine this from B-Tree + next_table_id: AtomicU64::new(2), // table id 1 / root page 1 is always sqlite_schema. 
clock, storage, loaded_tables: RwLock::new(HashSet::new()), exclusive_tx: RwLock::new(None), commit_coordinator: Arc::new(CommitCoordinator { pager_commit_lock: Arc::new(TursoRwLock::new()), - commits_waiting: Arc::new(AtomicU64::new(0)), }), global_header: Arc::new(RwLock::new(None)), blocking_checkpoint_lock: Arc::new(TursoRwLock::new()), + checkpointed_txid_max: AtomicU64::new(0), } } + /// MVCC does not use the pager/btree cursors to create pages until checkpoint. + /// This method is used to assign root page numbers when Insn::CreateBtree is used. + /// NOTE: during MVCC recovery (not implemented yet), [MvStore::next_table_id] must be + /// initialized to the current highest table id / root page number. + pub fn get_next_table_id(&self) -> u64 { + self.next_table_id.fetch_add(1, Ordering::SeqCst) + } + pub fn get_next_rowid(&self) -> i64 { self.next_rowid.fetch_add(1, Ordering::SeqCst) as i64 } @@ -1553,13 +1272,11 @@ impl MvStore { pub fn commit_tx( &self, tx_id: TxID, - pager: Arc, connection: &Arc, ) -> Result>> { let state_machine: StateMachine> = StateMachine::>::new(CommitStateMachine::new( CommitState::Initial, - pager, tx_id, connection.clone(), self.commit_coordinator.clone(), @@ -1819,7 +1536,7 @@ impl MvStore { // Then, scan the disk B-tree to find existing rows self.scan_load_table(table_id, pager)?; - self.loaded_tables.write().insert(table_id); + self.mark_table_as_loaded(table_id); Ok(()) } diff --git a/core/mvcc/database/tests.rs b/core/mvcc/database/tests.rs index ff35ba9ab..1f29160e4 100644 --- a/core/mvcc/database/tests.rs +++ b/core/mvcc/database/tests.rs @@ -760,9 +760,7 @@ pub(crate) fn commit_tx( conn: &Arc, tx_id: u64, ) -> Result<()> { - let mut sm = mv_store - .commit_tx(tx_id, conn.pager.read().clone(), conn) - .unwrap(); + let mut sm = mv_store.commit_tx(tx_id, conn).unwrap(); // TODO: sync IO hack loop { let res = sm.step(&mv_store)?; @@ -783,9 +781,7 @@ pub(crate) fn commit_tx_no_conn( conn: &Arc, ) -> Result<(), LimboError> { let 
mv_store = db.get_mvcc_store(); - let mut sm = mv_store - .commit_tx(tx_id, conn.pager.read().clone(), conn) - .unwrap(); + let mut sm = mv_store.commit_tx(tx_id, conn).unwrap(); // TODO: sync IO hack loop { let res = sm.step(&mv_store)?; diff --git a/core/mvcc/persistent_storage/logical_log.rs b/core/mvcc/persistent_storage/logical_log.rs index 1ba1a07d1..a2b2f50c9 100644 --- a/core/mvcc/persistent_storage/logical_log.rs +++ b/core/mvcc/persistent_storage/logical_log.rs @@ -184,4 +184,15 @@ impl LogicalLog { let c = self.file.sync(completion)?; Ok(IOResult::IO(IOCompletions::Single(c))) } + + pub fn truncate(&mut self) -> Result> { + let completion = Completion::new_trunc(move |result| { + if let Err(err) = result { + tracing::error!("logical_log_truncate failed: {}", err); + } + }); + let c = self.file.truncate(0, completion)?; + self.offset = 0; + Ok(IOResult::IO(IOCompletions::Single(c))) + } } diff --git a/core/mvcc/persistent_storage/mod.rs b/core/mvcc/persistent_storage/mod.rs index cfe977a5f..58af1b849 100644 --- a/core/mvcc/persistent_storage/mod.rs +++ b/core/mvcc/persistent_storage/mod.rs @@ -31,6 +31,10 @@ impl Storage { pub fn sync(&self) -> Result> { self.logical_log.write().unwrap().sync() } + + pub fn truncate(&self) -> Result> { + self.logical_log.write().unwrap().truncate() + } } impl Debug for Storage { diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 176715ace..203d48836 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -1,8 +1,10 @@ #![allow(unused_variables)] use crate::error::SQLITE_CONSTRAINT_UNIQUE; use crate::function::AlterTableFunc; +use crate::mvcc::database::CheckpointStateMachine; use crate::numeric::{NullableInteger, Numeric}; use crate::schema::Table; +use crate::state_machine::StateMachine; use crate::storage::btree::{ integrity_check, IntegrityCheckError, IntegrityCheckState, PageCategory, }; @@ -33,7 +35,7 @@ use crate::{ }, translate::emitter::TransactionMode, }; -use crate::{get_cursor, MvCursor}; 
+use crate::{get_cursor, CheckpointMode, MvCursor}; use std::env::temp_dir; use std::ops::DerefMut; use std::{ @@ -376,6 +378,31 @@ pub fn op_checkpoint_inner( // however. return Err(LimboError::TableLocked); } + if let Some(mv_store) = mv_store { + if !matches!(checkpoint_mode, CheckpointMode::Truncate { .. }) { + return Err(LimboError::InvalidArgument( + "Only TRUNCATE checkpoint mode is supported for MVCC".to_string(), + )); + } + let mut ckpt_sm = StateMachine::new(CheckpointStateMachine::new( + pager.clone(), + mv_store.clone(), + program.connection.clone(), + )); + loop { + let result = ckpt_sm.step(&())?; + match result { + IOResult::IO(io) => { + pager.io.step()?; + } + IOResult::Done(result) => { + state.op_checkpoint_state = + OpCheckpointState::CompleteResult { result: Ok(result) }; + break; + } + } + } + } loop { match &mut state.op_checkpoint_state { OpCheckpointState::StartCheckpoint => { @@ -6629,6 +6656,13 @@ pub fn op_create_btree( // TODO: implement temp databases todo!("temp databases not implemented yet"); } + + if let Some(mv_store) = mv_store { + let root_page = mv_store.get_next_table_id(); + state.registers[*root] = Register::Value(Value::Integer(root_page as i64)); + state.pc += 1; + return Ok(InsnFunctionStepResult::Step); + } // FIXME: handle page cache is full let root_page = return_if_io!(pager.btree_create(flags)); state.registers[*root] = Register::Value(Value::Integer(root_page as i64)); diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index f70028f47..422737140 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -837,7 +837,7 @@ impl Program { let Some((tx_id, _)) = conn.mv_tx.get() else { return Ok(IOResult::Done(())); }; - let state_machine = mv_store.commit_tx(tx_id, pager.clone(), &conn).unwrap(); + let state_machine = mv_store.commit_tx(tx_id, &conn).unwrap(); program_state.commit_state = CommitState::CommitingMvcc { state_machine }; } let CommitState::CommitingMvcc { state_machine } = &mut program_state.commit_state diff 
--git a/tests/integration/fuzz_transaction/mod.rs b/tests/integration/fuzz_transaction/mod.rs index cc8e35724..c969a40cf 100644 --- a/tests/integration/fuzz_transaction/mod.rs +++ b/tests/integration/fuzz_transaction/mod.rs @@ -494,13 +494,14 @@ async fn test_multiple_connections_fuzz() { async fn test_multiple_connections_fuzz_mvcc() { let mvcc_fuzz_options = FuzzOptions { mvcc_enabled: true, - max_num_connections: 8, + max_num_connections: 2, query_gen_options: QueryGenOptions { weight_begin_deferred: 4, weight_begin_concurrent: 12, weight_commit: 8, weight_rollback: 8, - weight_checkpoint: 0, + weight_checkpoint: 2, + checkpoint_modes: vec![CheckpointMode::Truncate], weight_ddl: 0, weight_dml: 76, dml_gen_options: DmlGenOptions { @@ -531,6 +532,7 @@ struct QueryGenOptions { weight_commit: usize, weight_rollback: usize, weight_checkpoint: usize, + checkpoint_modes: Vec, weight_ddl: usize, weight_dml: usize, dml_gen_options: DmlGenOptions, @@ -564,6 +566,12 @@ impl Default for QueryGenOptions { weight_commit: 10, weight_rollback: 10, weight_checkpoint: 5, + checkpoint_modes: vec![ + CheckpointMode::Passive, + CheckpointMode::Restart, + CheckpointMode::Truncate, + CheckpointMode::Full, + ], weight_ddl: 5, weight_dml: 55, dml_gen_options: DmlGenOptions::default(), @@ -587,7 +595,8 @@ async fn multiple_connections_fuzz(opts: FuzzOptions) { println!("Multiple connections fuzz test seed: {seed}"); for iteration in 0..opts.num_iterations { - let num_connections = rng.random_range(2..=opts.max_num_connections); + let num_connections = + rng.random_range(2.min(opts.max_num_connections)..=opts.max_num_connections); println!("--- Seed {seed} Iteration {iteration} ---"); println!("Options: {opts:?}"); // Create a fresh database for each iteration @@ -1104,14 +1113,15 @@ fn generate_operation( ) } } else if range_checkpoint.contains(&random_val) { - let mode = match rng.random_range(0..=3) { - 0 => CheckpointMode::Passive, - 1 => CheckpointMode::Restart, - 2 => 
CheckpointMode::Truncate, - 3 => CheckpointMode::Full, - _ => unreachable!(), - }; - (Operation::Checkpoint { mode }, get_visible_rows()) + let mode = shadow_db + .query_gen_options + .checkpoint_modes + .choose(rng) + .unwrap(); + ( + Operation::Checkpoint { mode: mode.clone() }, + get_visible_rows(), + ) } else if range_ddl.contains(&random_val) { let op = match rng.random_range(0..6) { 0..=2 => AlterTableOp::AddColumn { diff --git a/tests/integration/query_processing/test_transactions.rs b/tests/integration/query_processing/test_transactions.rs index 5de2bb566..e641ffb11 100644 --- a/tests/integration/query_processing/test_transactions.rs +++ b/tests/integration/query_processing/test_transactions.rs @@ -458,6 +458,101 @@ fn test_mvcc_concurrent_conflicting_update_2() { assert!(matches!(err, LimboError::WriteWriteConflict)); } +#[test] +fn test_mvcc_checkpoint_works() { + let tmp_db = TempDatabase::new_with_opts( + "test_mvcc_checkpoint_works.db", + turso_core::DatabaseOpts::new().with_mvcc(true), + ); + + // Create table + let conn = tmp_db.connect_limbo(); + conn.execute("CREATE TABLE test (id INTEGER, value TEXT)") + .unwrap(); + + // Insert rows from multiple connections + let mut expected_rows = Vec::new(); + + // Create 5 connections, each inserting 20 rows + for conn_id in 0..5 { + let conn = tmp_db.connect_limbo(); + conn.execute("BEGIN CONCURRENT").unwrap(); + + // Each connection inserts rows with its own pattern + for i in 0..20 { + let id = conn_id * 100 + i; + let value = format!("value_conn{conn_id}_row{i}"); + conn.execute(format!( + "INSERT INTO test (id, value) VALUES ({id}, '{value}')", + )) + .unwrap(); + expected_rows.push((id, value)); + } + + conn.execute("COMMIT").unwrap(); + } + + // Before checkpoint: assert that the DB file size is exactly 4096, the .db-wal file is empty (0 bytes), and there is a nonzero size .db-lg file + let db_file_size = std::fs::metadata(&tmp_db.path).unwrap().len(); + assert!(db_file_size == 4096); + let wal_file_size =
std::fs::metadata(tmp_db.path.with_extension("db-wal")) + .unwrap() + .len(); + assert!( + wal_file_size == 0, + "wal file size should be 0 bytes, but is {wal_file_size} bytes" + ); + let lg_file_size = std::fs::metadata(tmp_db.path.with_extension("db-lg")) + .unwrap() + .len(); + assert!(lg_file_size > 0); + + // Sort expected rows to match ORDER BY id, value + expected_rows.sort_by(|a, b| match a.0.cmp(&b.0) { + std::cmp::Ordering::Equal => a.1.cmp(&b.1), + other => other, + }); + + // Checkpoint + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)").unwrap(); + + // Verify all rows after reopening database + let tmp_db = TempDatabase::new_with_existent(&tmp_db.path, true); + let conn = tmp_db.connect_limbo(); + let stmt = conn + .query("SELECT * FROM test ORDER BY id, value") + .unwrap() + .unwrap(); + let rows = helper_read_all_rows(stmt); + + // Build expected results + let expected: Vec> = expected_rows + .into_iter() + .map(|(id, value)| vec![Value::Integer(id as i64), Value::build_text(value)]) + .collect(); + + assert_eq!(rows, expected); + + // Assert that the db file size is larger than 4096, assert .db-wal size is 0 bytes, assert the .db-lg file is empty (0 bytes) + let db_file_size = std::fs::metadata(&tmp_db.path).unwrap().len(); + assert!(db_file_size > 4096); + assert!(db_file_size % 4096 == 0); + let wal_size = std::fs::metadata(tmp_db.path.with_extension("db-wal")) + .unwrap() + .len(); + assert!( + wal_size == 0, + "wal size should be 0 bytes, but is {wal_size} bytes" + ); + let log_size = std::fs::metadata(tmp_db.path.with_extension("db-lg")) + .unwrap() + .len(); + assert!( + log_size == 0, + "log size should be 0 bytes, but is {log_size} bytes" + ); +} + fn helper_read_all_rows(mut stmt: turso_core::Statement) -> Vec> { let mut ret = Vec::new(); loop {