diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index 54eedeebb..596825eac 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -305,21 +305,40 @@ impl StateTransition for StateMachine { pub enum CommitState { Initial, BeginPagerTxn { end_ts: u64 }, - WriteRows { end_ts: u64 }, + WriteRow { end_ts: u64, write_set_index: usize }, + WriteRowStateMachine { end_ts: u64, write_set_index: usize }, CommitPagerTxn { end_ts: u64 }, Commit { end_ts: u64 }, } -struct CommitStateMachine { +#[derive(Debug)] +pub enum WriteRowState { + Initial, + CreateCursor, + Seek, + Insert, +} + +pub struct CommitStateMachine { state: CommitState, is_finalized: bool, pager: Rc, tx_id: TxID, connection: Arc, write_set: Vec, + write_row_state_machine: Option>, _phantom: PhantomData, } +pub struct WriteRowStateMachine { + state: WriteRowState, + is_finalized: bool, + pager: Rc, + row: Row, + record: Option, + cursor: Option, +} + impl CommitStateMachine { fn new(state: CommitState, pager: Rc, tx_id: TxID, connection: Arc) -> Self { Self { @@ -329,11 +348,25 @@ impl CommitStateMachine { tx_id, connection, write_set: Vec::new(), + write_row_state_machine: None, _phantom: PhantomData, } } } +impl WriteRowStateMachine { + fn new(pager: Rc, row: Row) -> Self { + Self { + state: WriteRowState::Initial, + is_finalized: false, + pager, + row, + record: None, + cursor: None, + } + } +} + impl StateTransition for CommitStateMachine { type State = CommitStateMachine; type Context = MvStore; @@ -466,35 +499,72 @@ impl StateTransition for CommitStateMachine { } } } - self.state = CommitState::WriteRows { end_ts }; + self.state = CommitState::WriteRow { + end_ts, + write_set_index: 0, + }; return Ok(TransitionResult::Continue); } - CommitState::WriteRows { end_ts } => { - for id in &self.write_set { - if let Some(row_versions) = mvcc_store.rows.get(id) { - let row_versions = row_versions.value().read(); - // Find rows that were written by this transaction - for row_version in row_versions.iter() { - if let TxTimestampOrID::TxID(row_tx_id) = row_version.begin { - if row_tx_id == self.tx_id { - mvcc_store - .write_row_to_pager(self.pager.clone(), &row_version.row)?; - break; - } + CommitState::WriteRow { + end_ts, + write_set_index, + } => { + if write_set_index == self.write_set.len() { + self.state = CommitState::CommitPagerTxn { end_ts }; + return Ok(TransitionResult::Continue); + } + let id = &self.write_set[write_set_index]; + if let Some(row_versions) = mvcc_store.rows.get(id) { + let row_versions = row_versions.value().read(); + // Find rows that were written by this transaction + for row_version in row_versions.iter() { + if let TxTimestampOrID::TxID(row_tx_id) = row_version.begin { + if row_tx_id == self.tx_id { + let state_machine = mvcc_store + .write_row_to_pager(self.pager.clone(), &row_version.row)?; + self.write_row_state_machine = Some(state_machine); + self.state = CommitState::WriteRowStateMachine { + end_ts, + write_set_index, + }; + break; } - if let Some(TxTimestampOrID::Timestamp(row_tx_id)) = row_version.end { - if row_tx_id == self.tx_id { - mvcc_store - .write_row_to_pager(self.pager.clone(), &row_version.row)?; - break; - } + } + if let Some(TxTimestampOrID::Timestamp(row_tx_id)) = row_version.end { + if row_tx_id == self.tx_id { + let state_machine = mvcc_store + .write_row_to_pager(self.pager.clone(), &row_version.row)?; + self.write_row_state_machine = Some(state_machine); + self.state = CommitState::WriteRowStateMachine { + end_ts, + write_set_index, + }; + break; } } } } - self.state = CommitState::CommitPagerTxn { end_ts }; Ok(TransitionResult::Continue) } + CommitState::WriteRowStateMachine { + end_ts, + write_set_index, + } => { + let write_row_state_machine = self.write_row_state_machine.as_mut().unwrap(); + match write_row_state_machine.transition(&())? { + TransitionResult::Io => return Ok(TransitionResult::Io), + TransitionResult::Continue => { + return Ok(TransitionResult::Continue); + } + TransitionResult::Done => { + self.state = CommitState::WriteRow { + end_ts, + write_set_index: write_set_index + 1, + }; + return Ok(TransitionResult::Continue); + } + } + } CommitState::CommitPagerTxn { end_ts } => { // Write committed data to pager for persistence // Flush dirty pages to WAL - this is critical for data persistence @@ -579,6 +649,95 @@ impl StateTransition for CommitStateMachine { } } +impl StateTransition for WriteRowStateMachine { + type State = WriteRowStateMachine; + type Context = (); + + #[tracing::instrument(fields(state = ?self.state), skip(self, _context))] + fn transition<'a>(&mut self, _context: &Self::Context) -> Result { + use crate::storage::btree::BTreeCursor; + use crate::types::{IOResult, SeekKey, SeekOp}; + + match self.state { + WriteRowState::Initial => { + // Create the record and key + let mut record = ImmutableRecord::new(self.row.data.len()); + record.start_serialization(&self.row.data); + self.record = Some(record); + + self.state = WriteRowState::CreateCursor; + Ok(TransitionResult::Continue) + } + WriteRowState::CreateCursor => { + // Create the cursor + let root_page = self.row.id.table_id as usize; + let num_columns = self.row.column_count; + + let cursor = BTreeCursor::new_table( + None, // Write directly to B-tree + self.pager.clone(), + root_page, + num_columns, + ); + self.cursor = Some(cursor); + + self.state = WriteRowState::Seek; + Ok(TransitionResult::Continue) + } + WriteRowState::Seek => { + // Position the cursor by seeking to the row position + let seek_key = SeekKey::TableRowId(self.row.id.row_id); + let cursor = self.cursor.as_mut().unwrap(); + + match cursor + .seek(seek_key, SeekOp::GE { eq_only: true }) + .map_err(|e| DatabaseError::Io(e.to_string()))? + { + IOResult::Done(_) => { + self.state = WriteRowState::Insert; + Ok(TransitionResult::Continue) + } + IOResult::IO => { + return Ok(TransitionResult::Io); + } + } + } + WriteRowState::Insert => { + // Insert the record into the B-tree + let cursor = self.cursor.as_mut().unwrap(); + let key = BTreeKey::new_table_rowid(self.row.id.row_id, self.record.as_ref()); + + match cursor + .insert(&key, true) + .map_err(|e| DatabaseError::Io(e.to_string()))? + { + IOResult::Done(()) => { + tracing::trace!( + "write_row_to_pager(table_id={}, row_id={})", + self.row.id.table_id, + self.row.id.row_id + ); + self.finalize(&())?; + Ok(TransitionResult::Done) + } + IOResult::IO => { + return Ok(TransitionResult::Io); + } + } + } + } + } + + fn finalize<'a>(&mut self, _context: &Self::Context) -> Result<()> { + self.is_finalized = true; + Ok(()) + } + + fn is_finalized(&self) -> bool { + self.is_finalized + } +} + /// A multi-version concurrency control database. #[derive(Debug)] pub struct MvStore { @@ -853,15 +1012,13 @@ impl MvStore { tx_id: TxID, pager: Rc, connection: &Arc, - ) -> Result<()> { - let mut state_machine: StateMachine> = StateMachine::< + ) -> Result>> { + let state_machine: StateMachine> = StateMachine::< CommitStateMachine, >::new( CommitStateMachine::new(CommitState::Initial, pager, tx_id, connection.clone()), ); - state_machine.transition(self)?; - assert!(state_machine.is_finalized()); - Ok(()) + Ok(state_machine) } /// Rolls back a transaction with the specified ID. @@ -1021,64 +1178,18 @@ impl MvStore { versions.insert(position, row_version); } - pub fn write_row_to_pager(&self, pager: Rc, row: &Row) -> Result<()> { - use crate::storage::btree::BTreeCursor; - use crate::types::{IOResult, SeekKey, SeekOp}; + pub fn write_row_to_pager( + &self, + pager: Rc, + row: &Row, + ) -> Result> { + let state_machine: StateMachine = + StateMachine::::new(WriteRowStateMachine::new( + pager, + row.clone(), + )); - // The row.data is already a properly serialized SQLite record payload - // Create an ImmutableRecord and copy the data - let mut record = ImmutableRecord::new(row.data.len()); - record.start_serialization(&row.data); - - // Create a BTreeKey for the row - let key = BTreeKey::new_table_rowid(row.id.row_id, Some(&record)); - - // Get the column count from the row - let root_page = row.id.table_id as usize; - let num_columns = row.column_count; - - let mut cursor = BTreeCursor::new_table( - None, // Write directly to B-tree - pager.clone(), - root_page, - num_columns, - ); - - // Position the cursor first by seeking to the row position - let seek_key = SeekKey::TableRowId(row.id.row_id); - match cursor - .seek(seek_key, SeekOp::GE { eq_only: true }) - .map_err(|e| DatabaseError::Io(e.to_string()))? - { - IOResult::Done(_) => {} - IOResult::IO => { - panic!("IOResult::IO not supported in write_row_to_pager seek"); - } - } - - // Insert the record into the B-tree - loop { - match cursor - .insert(&key, true) - .map_err(|e| DatabaseError::Io(e.to_string())) - { - Ok(IOResult::Done(())) => break, - Ok(IOResult::IO) => { - pager.io.run_once().unwrap(); - continue; - } - Err(e) => { - return Err(DatabaseError::Io(e.to_string())); - } - } - } - - tracing::trace!( - "write_row_to_pager(table_id={}, row_id={})", - row.id.table_id, - row.id.row_id - ); - Ok(()) + Ok(state_machine) } /// Try to scan for row ids in the table. diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index abbcd3f25..9e3ebbb76 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -27,6 +27,7 @@ pub mod sorter; use crate::{ error::LimboError, function::{AggFunc, FuncCtx}, + mvcc::database::StateTransition, storage::sqlite3_ondisk::SmallVec, translate::plan::TableReferences, types::{IOResult, RawSlice, TextRef}, @@ -442,7 +443,12 @@ impl Program { // FIXME: we don't want to commit stuff from other programs. let mut mv_transactions = conn.mv_transactions.borrow_mut(); for tx_id in mv_transactions.iter() { - mv_store.commit_tx(*tx_id, pager.clone(), &conn).unwrap(); + let mut state_machine = + mv_store.commit_tx(*tx_id, pager.clone(), &conn).unwrap(); + state_machine + .transition(&mv_store) + .map_err(|e| LimboError::InternalError(e.to_string()))?; + assert!(state_machine.is_finalized()); } mv_transactions.clear(); }