core/mvcc commit_txn generic state machinery

Unfortunately it seems we are never reaching the point to remove state
machines, so might as well make it easier to make.

There are two points that must be highlighted:
1. There is a `StateTransition` trait implemented like:

```rust
pub trait StateTransition {
    type State;
    type Context;

    fn transition<'a>(&mut self, context: &Self::Context) ->
Result<TransitionResult>;
    fn finalize<'a>(&mut self, context: &Self::Context) -> Result<()>;
    fn is_finalized(&self) -> bool;
}
```

where there exists `transition` which tries to move state forward, and
`finalize` which marks the state machine as "finalized" so that **no
other call to finalize will forward the state and it will panic instead.

2. Before, we would store the state of a state machine inside the
callee's struct, but I'm proposing we do something different where the
callee will return the state machine and the caller will be responsible
of advancing it. This way we don't need to track many reset operations
in case of failures or rollbacks, and instead we could simply drop a
state machine and all other nested state machines will drop in a
cascade.
This commit is contained in:
Pere Diaz Bou
2025-08-01 12:35:44 +02:00
parent d616a375ee
commit 27757ab4eb
2 changed files with 202 additions and 85 deletions

View File

@@ -305,21 +305,40 @@ impl<State: StateTransition> StateTransition for StateMachine<State> {
pub enum CommitState {
Initial,
BeginPagerTxn { end_ts: u64 },
WriteRows { end_ts: u64 },
WriteRow { end_ts: u64, write_set_index: usize },
WriteRowStateMachine { end_ts: u64, write_set_index: usize },
CommitPagerTxn { end_ts: u64 },
Commit { end_ts: u64 },
}
struct CommitStateMachine<Clock: LogicalClock> {
#[derive(Debug)]
pub enum WriteRowState {
Initial,
CreateCursor,
Seek,
Insert,
}
pub struct CommitStateMachine<Clock: LogicalClock> {
state: CommitState,
is_finalized: bool,
pager: Rc<Pager>,
tx_id: TxID,
connection: Arc<Connection>,
write_set: Vec<RowID>,
write_row_state_machine: Option<StateMachine<WriteRowStateMachine>>,
_phantom: PhantomData<Clock>,
}
pub struct WriteRowStateMachine {
state: WriteRowState,
is_finalized: bool,
pager: Rc<Pager>,
row: Row,
record: Option<ImmutableRecord>,
cursor: Option<BTreeCursor>,
}
impl<Clock: LogicalClock> CommitStateMachine<Clock> {
fn new(state: CommitState, pager: Rc<Pager>, tx_id: TxID, connection: Arc<Connection>) -> Self {
Self {
@@ -329,11 +348,25 @@ impl<Clock: LogicalClock> CommitStateMachine<Clock> {
tx_id,
connection,
write_set: Vec::new(),
write_row_state_machine: None,
_phantom: PhantomData,
}
}
}
impl WriteRowStateMachine {
fn new(pager: Rc<Pager>, row: Row) -> Self {
Self {
state: WriteRowState::Initial,
is_finalized: false,
pager,
row,
record: None,
cursor: None,
}
}
}
impl<Clock: LogicalClock> StateTransition for CommitStateMachine<Clock> {
type State = CommitStateMachine<Clock>;
type Context = MvStore<Clock>;
@@ -466,35 +499,72 @@ impl<Clock: LogicalClock> StateTransition for CommitStateMachine<Clock> {
}
}
}
self.state = CommitState::WriteRows { end_ts };
self.state = CommitState::WriteRow {
end_ts,
write_set_index: 0,
};
return Ok(TransitionResult::Continue);
}
CommitState::WriteRows { end_ts } => {
for id in &self.write_set {
if let Some(row_versions) = mvcc_store.rows.get(id) {
let row_versions = row_versions.value().read();
// Find rows that were written by this transaction
for row_version in row_versions.iter() {
if let TxTimestampOrID::TxID(row_tx_id) = row_version.begin {
if row_tx_id == self.tx_id {
mvcc_store
.write_row_to_pager(self.pager.clone(), &row_version.row)?;
break;
}
CommitState::WriteRow {
end_ts,
write_set_index,
} => {
if write_set_index == self.write_set.len() {
self.state = CommitState::CommitPagerTxn { end_ts };
return Ok(TransitionResult::Continue);
}
let id = &self.write_set[write_set_index];
if let Some(row_versions) = mvcc_store.rows.get(id) {
let row_versions = row_versions.value().read();
// Find rows that were written by this transaction
for row_version in row_versions.iter() {
if let TxTimestampOrID::TxID(row_tx_id) = row_version.begin {
if row_tx_id == self.tx_id {
let state_machine = mvcc_store
.write_row_to_pager(self.pager.clone(), &row_version.row)?;
self.write_row_state_machine = Some(state_machine);
self.state = CommitState::WriteRowStateMachine {
end_ts,
write_set_index,
};
break;
}
if let Some(TxTimestampOrID::Timestamp(row_tx_id)) = row_version.end {
if row_tx_id == self.tx_id {
mvcc_store
.write_row_to_pager(self.pager.clone(), &row_version.row)?;
break;
}
}
if let Some(TxTimestampOrID::Timestamp(row_tx_id)) = row_version.end {
if row_tx_id == self.tx_id {
let state_machine = mvcc_store
.write_row_to_pager(self.pager.clone(), &row_version.row)?;
self.write_row_state_machine = Some(state_machine);
self.state = CommitState::WriteRowStateMachine {
end_ts,
write_set_index,
};
break;
}
}
}
}
self.state = CommitState::CommitPagerTxn { end_ts };
Ok(TransitionResult::Continue)
}
CommitState::WriteRowStateMachine {
end_ts,
write_set_index,
} => {
let write_row_state_machine = self.write_row_state_machine.as_mut().unwrap();
match write_row_state_machine.transition(&())? {
TransitionResult::Io => return Ok(TransitionResult::Io),
TransitionResult::Continue => {
return Ok(TransitionResult::Continue);
}
TransitionResult::Done => {
self.state = CommitState::WriteRow {
end_ts,
write_set_index: write_set_index + 1,
};
return Ok(TransitionResult::Continue);
}
}
}
CommitState::CommitPagerTxn { end_ts } => {
// Write committed data to pager for persistence
// Flush dirty pages to WAL - this is critical for data persistence
@@ -579,6 +649,95 @@ impl<Clock: LogicalClock> StateTransition for CommitStateMachine<Clock> {
}
}
impl StateTransition for WriteRowStateMachine {
type State = WriteRowStateMachine;
type Context = ();
#[tracing::instrument(fields(state = ?self.state), skip(self, _context))]
fn transition<'a>(&mut self, _context: &Self::Context) -> Result<TransitionResult> {
use crate::storage::btree::BTreeCursor;
use crate::types::{IOResult, SeekKey, SeekOp};
match self.state {
WriteRowState::Initial => {
// Create the record and key
let mut record = ImmutableRecord::new(self.row.data.len());
record.start_serialization(&self.row.data);
self.record = Some(record);
self.state = WriteRowState::CreateCursor;
Ok(TransitionResult::Continue)
}
WriteRowState::CreateCursor => {
// Create the cursor
let root_page = self.row.id.table_id as usize;
let num_columns = self.row.column_count;
let cursor = BTreeCursor::new_table(
None, // Write directly to B-tree
self.pager.clone(),
root_page,
num_columns,
);
self.cursor = Some(cursor);
self.state = WriteRowState::Seek;
Ok(TransitionResult::Continue)
}
WriteRowState::Seek => {
// Position the cursor by seeking to the row position
let seek_key = SeekKey::TableRowId(self.row.id.row_id);
let cursor = self.cursor.as_mut().unwrap();
match cursor
.seek(seek_key, SeekOp::GE { eq_only: true })
.map_err(|e| DatabaseError::Io(e.to_string()))?
{
IOResult::Done(_) => {
self.state = WriteRowState::Insert;
Ok(TransitionResult::Continue)
}
IOResult::IO => {
return Ok(TransitionResult::Io);
}
}
}
WriteRowState::Insert => {
// Insert the record into the B-tree
let cursor = self.cursor.as_mut().unwrap();
let key = BTreeKey::new_table_rowid(self.row.id.row_id, self.record.as_ref());
match cursor
.insert(&key, true)
.map_err(|e| DatabaseError::Io(e.to_string()))?
{
IOResult::Done(()) => {
tracing::trace!(
"write_row_to_pager(table_id={}, row_id={})",
self.row.id.table_id,
self.row.id.row_id
);
self.finalize(&())?;
Ok(TransitionResult::Done)
}
IOResult::IO => {
return Ok(TransitionResult::Io);
}
}
}
}
}
fn finalize<'a>(&mut self, _context: &Self::Context) -> Result<()> {
self.is_finalized = true;
Ok(())
}
fn is_finalized(&self) -> bool {
self.is_finalized
}
}
/// A multi-version concurrency control database.
#[derive(Debug)]
pub struct MvStore<Clock: LogicalClock> {
@@ -853,15 +1012,13 @@ impl<Clock: LogicalClock> MvStore<Clock> {
tx_id: TxID,
pager: Rc<Pager>,
connection: &Arc<Connection>,
) -> Result<()> {
let mut state_machine: StateMachine<CommitStateMachine<Clock>> = StateMachine::<
) -> Result<StateMachine<CommitStateMachine<Clock>>> {
let state_machine: StateMachine<CommitStateMachine<Clock>> = StateMachine::<
CommitStateMachine<Clock>,
>::new(
CommitStateMachine::new(CommitState::Initial, pager, tx_id, connection.clone()),
);
state_machine.transition(self)?;
assert!(state_machine.is_finalized());
Ok(())
Ok(state_machine)
}
/// Rolls back a transaction with the specified ID.
@@ -1021,64 +1178,18 @@ impl<Clock: LogicalClock> MvStore<Clock> {
versions.insert(position, row_version);
}
pub fn write_row_to_pager(&self, pager: Rc<Pager>, row: &Row) -> Result<()> {
use crate::storage::btree::BTreeCursor;
use crate::types::{IOResult, SeekKey, SeekOp};
pub fn write_row_to_pager(
&self,
pager: Rc<Pager>,
row: &Row,
) -> Result<StateMachine<WriteRowStateMachine>> {
let state_machine: StateMachine<WriteRowStateMachine> =
StateMachine::<WriteRowStateMachine>::new(WriteRowStateMachine::new(
pager,
row.clone(),
));
// The row.data is already a properly serialized SQLite record payload
// Create an ImmutableRecord and copy the data
let mut record = ImmutableRecord::new(row.data.len());
record.start_serialization(&row.data);
// Create a BTreeKey for the row
let key = BTreeKey::new_table_rowid(row.id.row_id, Some(&record));
// Get the column count from the row
let root_page = row.id.table_id as usize;
let num_columns = row.column_count;
let mut cursor = BTreeCursor::new_table(
None, // Write directly to B-tree
pager.clone(),
root_page,
num_columns,
);
// Position the cursor first by seeking to the row position
let seek_key = SeekKey::TableRowId(row.id.row_id);
match cursor
.seek(seek_key, SeekOp::GE { eq_only: true })
.map_err(|e| DatabaseError::Io(e.to_string()))?
{
IOResult::Done(_) => {}
IOResult::IO => {
panic!("IOResult::IO not supported in write_row_to_pager seek");
}
}
// Insert the record into the B-tree
loop {
match cursor
.insert(&key, true)
.map_err(|e| DatabaseError::Io(e.to_string()))
{
Ok(IOResult::Done(())) => break,
Ok(IOResult::IO) => {
pager.io.run_once().unwrap();
continue;
}
Err(e) => {
return Err(DatabaseError::Io(e.to_string()));
}
}
}
tracing::trace!(
"write_row_to_pager(table_id={}, row_id={})",
row.id.table_id,
row.id.row_id
);
Ok(())
Ok(state_machine)
}
/// Try to scan for row ids in the table.

View File

@@ -27,6 +27,7 @@ pub mod sorter;
use crate::{
error::LimboError,
function::{AggFunc, FuncCtx},
mvcc::database::StateTransition,
storage::sqlite3_ondisk::SmallVec,
translate::plan::TableReferences,
types::{IOResult, RawSlice, TextRef},
@@ -442,7 +443,12 @@ impl Program {
// FIXME: we don't want to commit stuff from other programs.
let mut mv_transactions = conn.mv_transactions.borrow_mut();
for tx_id in mv_transactions.iter() {
mv_store.commit_tx(*tx_id, pager.clone(), &conn).unwrap();
let mut state_machine =
mv_store.commit_tx(*tx_id, pager.clone(), &conn).unwrap();
state_machine
.transition(&mv_store)
.map_err(|e| LimboError::InternalError(e.to_string()))?;
assert!(state_machine.is_finalized());
}
mv_transactions.clear();
}