diff --git a/core/mvcc/bindings/c/include/mvcc.h b/core/mvcc/bindings/c/include/mvcc.h index b30f1cf4a..c12aa5a43 100644 --- a/core/mvcc/bindings/c/include/mvcc.h +++ b/core/mvcc/bindings/c/include/mvcc.h @@ -11,8 +11,12 @@ typedef enum { typedef struct DbContext DbContext; +typedef struct ScanCursorContext ScanCursorContext; + typedef const DbContext *MVCCDatabaseRef; +typedef ScanCursorContext *MVCCScanCursorRef; + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -27,6 +31,14 @@ MVCCError MVCCDatabaseRead(MVCCDatabaseRef db, uint64_t id, char **value_ptr, in void MVCCFreeStr(void *ptr); +MVCCScanCursorRef MVCCScanCursorOpen(MVCCDatabaseRef db); + +void MVCCScanCursorClose(MVCCScanCursorRef cursor); + +MVCCError MVCCScanCursorRead(MVCCScanCursorRef cursor, char **value_ptr, int64_t *value_len); + +int MVCCScanCursorNext(MVCCScanCursorRef cursor); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/core/mvcc/bindings/c/src/lib.rs b/core/mvcc/bindings/c/src/lib.rs index f995be1ea..7f3f1ea62 100644 --- a/core/mvcc/bindings/c/src/lib.rs +++ b/core/mvcc/bindings/c/src/lib.rs @@ -6,7 +6,7 @@ mod types; use errors::MVCCError; use mvcc_rs::*; -use types::{DbContext, MVCCDatabaseRef}; +use types::{DbContext, MVCCDatabaseRef, MVCCScanCursorRef, ScanCursorContext}; /// cbindgen:ignore type Clock = clock::LocalClock; @@ -20,6 +20,9 @@ type Inner = database::DatabaseInner; /// cbindgen:ignore type Db = database::Database>; +/// cbindgen:ignore +type ScanCursor = cursor::ScanCursor<'static, Clock, Storage, tokio::sync::Mutex>; + static INIT_RUST_LOG: std::sync::Once = std::sync::Once::new(); #[no_mangle] @@ -143,3 +146,97 @@ pub unsafe extern "C" fn MVCCFreeStr(ptr: *mut std::ffi::c_void) { } let _ = std::ffi::CString::from_raw(ptr as *mut std::ffi::c_char); } + +#[no_mangle] +pub unsafe extern "C" fn MVCCScanCursorOpen(db: MVCCDatabaseRef) -> MVCCScanCursorRef { + tracing::debug!("MVCCScanCursorOpen()"); + // Reference is transmuted to &'static in order to be able to pass the cursor back to C. + // The contract with C is to never use a cursor after MVCCDatabaseClose() has been called. + let database = unsafe { std::mem::transmute::<&DbContext, &'static DbContext>(db.get_ref()) }; + let (database, runtime) = (&database.db, &database.runtime); + match runtime.block_on(async move { mvcc_rs::cursor::ScanCursor::new(database).await }) { + Ok(cursor) => { + tracing::debug!("Cursor open: {cursor:?}"); + MVCCScanCursorRef { + ptr: Box::into_raw(Box::new(ScanCursorContext { cursor, db })), + } + } + Err(e) => { + tracing::error!("MVCCScanCursorOpen: {e}"); + MVCCScanCursorRef { + ptr: std::ptr::null_mut(), + } + } + } +} + +#[no_mangle] +pub unsafe extern "C" fn MVCCScanCursorClose(cursor: MVCCScanCursorRef) { + tracing::debug!("MVCCScanCursorClose()"); + if cursor.ptr.is_null() { + tracing::debug!("warning: `cursor` is null in MVCCScanCursorClose()"); + return; + } + let _ = unsafe { Box::from_raw(cursor.ptr) }; +} + +#[no_mangle] +pub unsafe extern "C" fn MVCCScanCursorRead( + cursor: MVCCScanCursorRef, + value_ptr: *mut *mut std::ffi::c_char, + value_len: *mut i64, +) -> MVCCError { + tracing::debug!("MVCCScanCursorRead()"); + if cursor.ptr.is_null() { + tracing::debug!("warning: `cursor` is null in MVCCScanCursorRead()"); + return MVCCError::MVCC_IO_ERROR_READ; + } + let cursor_ctx = unsafe { &*cursor.ptr }; + let runtime = &cursor_ctx.db.get_ref().runtime; + let cursor = &cursor_ctx.cursor; + + // TODO: deduplicate with MVCCDatabaseRead() + match runtime.block_on(async move { + let maybe_row = cursor.current().await?; + match maybe_row { + Some(row) => { + tracing::debug!("Found row {row:?}"); + let str_len = row.data.len() + 1; + let value = std::ffi::CString::new(row.data.as_bytes()).map_err(|e| { + mvcc_rs::errors::DatabaseError::Io(format!( + "Failed to transform read data into CString: {e}" + )) + })?; + unsafe { + *value_ptr = value.into_raw(); + *value_len = str_len as i64; + } + } + None => unsafe { *value_len = -1 }, + }; + Ok::<(), mvcc_rs::errors::DatabaseError>(()) + }) { + Ok(_) => { + tracing::debug!("MVCCDatabaseRead: success"); + MVCCError::MVCC_OK + } + Err(e) => { + tracing::error!("MVCCDatabaseRead: {e}"); + MVCCError::MVCC_IO_ERROR_READ + } + } +} + +#[no_mangle] +pub unsafe extern "C" fn MVCCScanCursorNext(cursor: MVCCScanCursorRef) -> std::ffi::c_int { + let cursor_ctx = unsafe { &mut *cursor.ptr }; + let cursor = &mut cursor_ctx.cursor; + tracing::debug!("MVCCScanCursorNext(): {}", cursor.index); + if cursor.forward() { + tracing::debug!("Forwarded to {}", cursor.index); + 1 + } else { + tracing::debug!("Forwarded to end"); + 0 + } +} diff --git a/core/mvcc/bindings/c/src/types.rs b/core/mvcc/bindings/c/src/types.rs index acb929dcd..34d035cb8 100644 --- a/core/mvcc/bindings/c/src/types.rs +++ b/core/mvcc/bindings/c/src/types.rs @@ -45,3 +45,13 @@ pub struct DbContext { pub(crate) db: Db, pub(crate) runtime: tokio::runtime::Runtime, } + +pub struct ScanCursorContext { + pub cursor: crate::ScanCursor, + pub db: MVCCDatabaseRef, +} + +#[repr(transparent)] +pub struct MVCCScanCursorRef { + pub ptr: *mut ScanCursorContext, +} diff --git a/core/mvcc/mvcc-rs/src/cursor.rs b/core/mvcc/mvcc-rs/src/cursor.rs new file mode 100644 index 000000000..230ad6ff6 --- /dev/null +++ b/core/mvcc/mvcc-rs/src/cursor.rs @@ -0,0 +1,48 @@ +use crate::clock::LogicalClock; +use crate::database::{Database, DatabaseInner, Result, Row}; +use crate::persistent_storage::Storage; +use crate::sync::AsyncMutex; + +#[derive(Debug)] +pub struct ScanCursor< + 'a, + Clock: LogicalClock, + StorageImpl: Storage, + Mutex: AsyncMutex>, +> { + pub db: &'a Database, + pub row_ids: Vec, + pub index: usize, + tx_id: u64, +} + +impl< + 'a, + Clock: LogicalClock, + StorageImpl: Storage, + Mutex: AsyncMutex>, + > ScanCursor<'a, Clock, StorageImpl, Mutex> +{ + pub async fn new( + db: &'a Database, + ) -> Result> { + let tx_id = db.begin_tx().await; + let row_ids = db.scan_row_ids().await?; + Ok(Self { + db, + tx_id, + row_ids, + index: 0, + }) + } + + pub async fn current(&self) -> Result> { + let id = self.row_ids[self.index]; + self.db.read(self.tx_id, id).await + } + + pub fn forward(&mut self) -> bool { + self.index += 1; + self.index < self.row_ids.len() + } +} diff --git a/core/mvcc/mvcc-rs/src/database.rs b/core/mvcc/mvcc-rs/src/database.rs index 42d0d1588..89e56aa30 100644 --- a/core/mvcc/mvcc-rs/src/database.rs +++ b/core/mvcc/mvcc-rs/src/database.rs @@ -225,6 +225,11 @@ impl< inner.read(tx_id, id).await } + pub async fn scan_row_ids(&self) -> Result> { + let inner = self.inner.lock().await; + inner.scan_row_ids() + } + /// Begins a new transaction in the database. /// /// This function starts a new transaction in the database and returns a `TxID` value @@ -355,6 +360,11 @@ impl Ok(None) } + fn scan_row_ids(&self) -> Result> { + let rows = self.rows.borrow(); + Ok(rows.keys().cloned().collect()) + } + async fn begin_tx(&mut self) -> TxID { let tx_id = self.get_tx_id(); let begin_ts = self.get_timestamp(); diff --git a/core/mvcc/mvcc-rs/src/lib.rs b/core/mvcc/mvcc-rs/src/lib.rs index d88011290..f8d418335 100644 --- a/core/mvcc/mvcc-rs/src/lib.rs +++ b/core/mvcc/mvcc-rs/src/lib.rs @@ -32,6 +32,7 @@ //! * Garbage collection pub mod clock; +pub mod cursor; pub mod database; pub mod errors; pub mod persistent_storage;