From 16547cb569bcd8e743700cb716598e17fe86110a Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Mon, 25 Aug 2025 14:25:26 +0300 Subject: [PATCH] sqlite3: Implement sqlite3_next_stmt() --- sqlite3/include/sqlite3.h | 2 + sqlite3/src/lib.rs | 49 +++++++++++++++++++++- sqlite3/tests/compat/mod.rs | 78 +++++++++++++++++++++++++++++++++++ sqlite3/tests/sqlite3_tests.c | 2 + 4 files changed, 130 insertions(+), 1 deletion(-) diff --git a/sqlite3/include/sqlite3.h b/sqlite3/include/sqlite3.h index f56f38b8c..0d098ce81 100644 --- a/sqlite3/include/sqlite3.h +++ b/sqlite3/include/sqlite3.h @@ -107,6 +107,8 @@ int sqlite3_stmt_readonly(sqlite3_stmt *_stmt); int sqlite3_stmt_busy(sqlite3_stmt *_stmt); +sqlite3_stmt *sqlite3_next_stmt(sqlite3 *db, sqlite3_stmt *stmt); + int sqlite3_serialize(sqlite3 *_db, const char *_schema, void **_out, int *_out_bytes, unsigned int _flags); int sqlite3_deserialize(sqlite3 *_db, const char *_schema, const void *_in_, int _in_bytes, unsigned int _flags); diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 46d6d64b3..1d0fa1bed 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -57,6 +57,7 @@ struct sqlite3Inner { pub(crate) e_open_state: u8, pub(crate) p_err: *mut ffi::c_void, pub(crate) filename: CString, + pub(crate) stmt_list: *mut sqlite3_stmt, } impl sqlite3 { @@ -76,6 +77,7 @@ impl sqlite3 { e_open_state: SQLITE_STATE_OPEN, p_err: std::ptr::null_mut(), filename, + stmt_list: std::ptr::null_mut(), }; #[allow(clippy::arc_with_non_send_sync)] let inner = Arc::new(Mutex::new(inner)); @@ -91,6 +93,7 @@ pub struct sqlite3_stmt { Option, *mut ffi::c_void, )>, + pub(crate) next: *mut sqlite3_stmt, } impl sqlite3_stmt { @@ -99,6 +102,7 @@ impl sqlite3_stmt { db, stmt, destructors: Vec::new(), + next: std::ptr::null_mut(), } } } @@ -279,7 +283,12 @@ pub unsafe extern "C" fn sqlite3_prepare_v2( return SQLITE_ERROR; } }; - *out_stmt = Box::leak(Box::new(sqlite3_stmt::new(raw_db, stmt))); + let new_stmt = Box::leak(Box::new(sqlite3_stmt::new(raw_db, stmt))); + + new_stmt.next = db.stmt_list; + db.stmt_list = new_stmt; + + *out_stmt = new_stmt; SQLITE_OK } @@ -290,6 +299,25 @@ pub unsafe extern "C" fn sqlite3_finalize(stmt: *mut sqlite3_stmt) -> ffi::c_int } let stmt_ref = &mut *stmt; + if !stmt_ref.db.is_null() { + let db = &mut *stmt_ref.db; + let mut db_inner = db.inner.lock().unwrap(); + + if db_inner.stmt_list == stmt { + db_inner.stmt_list = stmt_ref.next; + } else { + let mut current = db_inner.stmt_list; + while !current.is_null() { + let current_ref = &mut *current; + if current_ref.next == stmt { + current_ref.next = stmt_ref.next; + break; + } + current = current_ref.next; + } + } + } + for (_idx, destructor_opt, ptr) in stmt_ref.destructors.drain(..) { if let Some(destructor_fn) = destructor_opt { destructor_fn(ptr); @@ -381,6 +409,25 @@ pub unsafe extern "C" fn sqlite3_stmt_busy(_stmt: *mut sqlite3_stmt) -> ffi::c_i stub!(); } +/// Iterate over all prepared statements in the database. +#[no_mangle] +pub unsafe extern "C" fn sqlite3_next_stmt( + db: *mut sqlite3, + stmt: *mut sqlite3_stmt, +) -> *mut sqlite3_stmt { + if db.is_null() { + return std::ptr::null_mut(); + } + if stmt.is_null() { + let db = &*db; + let db = db.inner.lock().unwrap(); + db.stmt_list + } else { + let stmt = &mut *stmt; + stmt.next + } +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_serialize( _db: *mut sqlite3, diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index 94361cf11..52ed3d8fa 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -48,6 +48,7 @@ extern "C" { ) -> i32; fn libsql_wal_disable_checkpoint(db: *mut sqlite3) -> i32; fn sqlite3_column_int(stmt: *mut sqlite3_stmt, idx: i32) -> i64; + fn sqlite3_next_stmt(db: *mut sqlite3, stmt: *mut sqlite3_stmt) -> *mut sqlite3_stmt; fn sqlite3_bind_int(stmt: *mut sqlite3_stmt, idx: i32, val: i64) -> i32; fn sqlite3_bind_parameter_count(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_bind_parameter_name(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; @@ -1319,4 +1320,81 @@ mod tests { assert_eq!(sqlite3_close(db), SQLITE_OK); } } + + #[test] + fn test_sqlite3_next_stmt() { + const SQLITE_OK: i32 = 0; + + unsafe { + let mut db: *mut sqlite3 = ptr::null_mut(); + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + + // Initially, there should be no prepared statements + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert!(iter.is_null()); + + // Prepare first statement + let mut stmt1: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 1;".as_ptr(), -1, &mut stmt1, ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt1.is_null()); + + // Now there should be one statement + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert_eq!(iter, stmt1); + + // And no more after that + let iter = sqlite3_next_stmt(db, stmt1); + assert!(iter.is_null()); + + // Prepare second statement + let mut stmt2: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 2;".as_ptr(), -1, &mut stmt2, ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt2.is_null()); + + // Prepare third statement + let mut stmt3: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 3;".as_ptr(), -1, &mut stmt3, ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt3.is_null()); + + // Count all statements + let mut count = 0; + let mut iter = sqlite3_next_stmt(db, ptr::null_mut()); + while !iter.is_null() { + count += 1; + iter = sqlite3_next_stmt(db, iter); + } + assert_eq!(count, 3); + + // Finalize the middle statement + assert_eq!(sqlite3_finalize(stmt2), SQLITE_OK); + + // Count should now be 2 + count = 0; + iter = sqlite3_next_stmt(db, ptr::null_mut()); + while !iter.is_null() { + count += 1; + iter = sqlite3_next_stmt(db, iter); + } + assert_eq!(count, 2); + + // Finalize remaining statements + assert_eq!(sqlite3_finalize(stmt1), SQLITE_OK); + assert_eq!(sqlite3_finalize(stmt3), SQLITE_OK); + + // Should be no statements left + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert!(iter.is_null()); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } } diff --git a/sqlite3/tests/sqlite3_tests.c b/sqlite3/tests/sqlite3_tests.c index 2cd490a93..9fc1ffc49 100644 --- a/sqlite3/tests/sqlite3_tests.c +++ b/sqlite3/tests/sqlite3_tests.c @@ -18,6 +18,7 @@ void test_sqlite3_bind_text2(); void test_sqlite3_bind_blob(); void test_sqlite3_column_type(); void test_sqlite3_column_decltype(); +void test_sqlite3_next_stmt(); int allocated = 0; @@ -35,6 +36,7 @@ int main(void) test_sqlite3_bind_blob(); test_sqlite3_column_type(); test_sqlite3_column_decltype(); + test_sqlite3_next_stmt(); return 0; }