diff --git a/core/mvcc/cursor.rs b/core/mvcc/cursor.rs index cc8cde437..f10326df7 100644 --- a/core/mvcc/cursor.rs +++ b/core/mvcc/cursor.rs @@ -1,8 +1,10 @@ use crate::mvcc::clock::LogicalClock; use crate::mvcc::database::{MvStore, Row, RowID}; +use crate::types::{IOResult, SeekKey, SeekOp, SeekResult}; use crate::Pager; use crate::Result; use std::fmt::Debug; +use std::ops::Bound; use std::rc::Rc; use std::sync::Arc; @@ -151,4 +153,48 @@ impl MvccLazyCursor { CursorPosition::End => i64::MAX, } } + + pub fn seek(&mut self, seek_key: SeekKey<'_>, op: SeekOp) -> Result> { + let row_id = match seek_key { + SeekKey::TableRowId(row_id) => row_id, + SeekKey::IndexKey(_) => { + todo!(); + } + }; + // gt -> lower_bound bound excluded, we want first row after row_id + // ge -> lower_bound bound included, we want first row equal to row_id or first row after row_id + // lt -> upper_bound bound excluded, we want last row before row_id + // le -> upper_bound bound included, we want last row equal to row_id or first row before row_id + let rowid = RowID { + table_id: self.table_id, + row_id, + }; + let (bound, lower_bound) = match op { + SeekOp::GT => (Bound::Excluded(&rowid), true), + SeekOp::GE { eq_only: _ } => (Bound::Included(&rowid), true), + SeekOp::LT => (Bound::Excluded(&rowid), false), + SeekOp::LE { eq_only: _ } => (Bound::Included(&rowid), false), + }; + let rowid = self.db.seek_rowid(bound, lower_bound); + if let Some(rowid) = rowid { + self.current_pos = CursorPosition::Loaded(rowid); + if op.eq_only() { + if rowid.row_id == row_id { + Ok(IOResult::Done(SeekResult::Found)) + } else { + Ok(IOResult::Done(SeekResult::NotFound)) + } + } else { + Ok(IOResult::Done(SeekResult::Found)) + } + } else { + let forwards = matches!(op, SeekOp::GE { eq_only: _ } | SeekOp::GT); + if forwards { + self.last(); + } else { + self.rewind(); + } + Ok(IOResult::Done(SeekResult::NotFound)) + } + } } diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index b95ae07c6..4466d51d2 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -917,6 +917,16 @@ impl MvStore { .map(|entry| *entry.key()) } + pub fn seek_rowid(&self, bound: Bound<&RowID>, lower_bound: bool) -> Option { + tracing::trace!("seek_rowid(bound={:?}, lower_bound={})", bound, lower_bound,); + + if lower_bound { + self.rows.lower_bound(bound).map(|entry| *entry.key()) + } else { + self.rows.upper_bound(bound).map(|entry| *entry.key()) + } + } + /// Begins a new transaction in the database. /// /// This function starts a new transaction in the database and returns a `TxID` value diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 6c31b7fb2..d2974fb9f 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -4367,7 +4367,10 @@ impl BTreeCursor { #[instrument(skip(self), level = Level::DEBUG)] pub fn seek(&mut self, key: SeekKey<'_>, op: SeekOp) -> Result> { - assert!(self.mv_cursor.is_none()); + if let Some(mv_cursor) = &self.mv_cursor { + let mut mv_cursor = mv_cursor.borrow_mut(); + return mv_cursor.seek(key, op); + } // Empty trace to capture the span information tracing::trace!(""); // We need to clear the null flag for the table cursor before seeking,