Close statements in extension tests, and use mut pointers for stmt

This commit is contained in:
PThorpe92
2025-05-24 16:45:25 -04:00
parent d63f9d8cff
commit 1cacbf1f0d
3 changed files with 26 additions and 14 deletions

View File

@@ -89,7 +89,7 @@ pub unsafe extern "C" fn execute(
/// Wraps core Connection::prepare with a custom Stmt object with the necessary function pointers.
/// This object is boxed/leaked and the caller is responsible for freeing the memory.
pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) -> *const Stmt {
pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) -> *mut Stmt {
let c_str = unsafe { CStr::from_ptr(sql as *mut c_char) };
let sql_str = match c_str.to_str() {
Ok(s) => s.to_string(),
@@ -120,7 +120,7 @@ pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) ->
stmt_get_column_names,
stmt_free_current_row,
stmt_close,
))) as *const Stmt
)))
}
Err(e) => {
tracing::error!("prepare_stmt: failed to prepare statement: {:?}", e);

View File

@@ -343,7 +343,7 @@ impl ConstraintInfo {
}
}
pub type PrepareStmtFn = unsafe extern "C" fn(api: *mut Conn, sql: *const c_char) -> *const Stmt;
pub type PrepareStmtFn = unsafe extern "C" fn(api: *mut Conn, sql: *const c_char) -> *mut Stmt;
pub type ExecuteFn = unsafe extern "C" fn(
ctx: *mut Conn,
sql: *const c_char,
@@ -426,9 +426,9 @@ impl Conn {
Err(ResultCode::Error)
}
pub fn prepare_stmt(&self, sql: &str) -> *const Stmt {
pub fn prepare_stmt(&self, sql: &str) -> *mut Stmt {
let Ok(sql) = CString::new(sql) else {
return std::ptr::null();
return std::ptr::null_mut();
};
unsafe { (self._prepare_stmt)(self as *const _ as *mut Conn, sql.as_ptr()) }
}
@@ -438,8 +438,16 @@ impl Conn {
/// Statements can be manually closed.
#[derive(Debug)]
#[repr(transparent)]
pub struct Statement(*const Stmt);
pub struct Statement(*mut Stmt);
impl Drop for Statement {
fn drop(&mut self) {
if self.0.is_null() {
return;
}
unsafe { (*self.0).close() }
}
}
/// Public API for methods to allow extensions to query other tables for
/// the connection that opened the VTable.
#[derive(Debug)]
@@ -562,12 +570,13 @@ impl Stmt {
}
/// Close the statement
pub fn close(&self) {
pub fn close(&mut self) {
// null check to prevent double free
if self._ctx.is_null() {
return;
}
unsafe { (self._close)(self as *const Stmt as *mut Stmt) };
self._ctx = std::ptr::null_mut();
}
/// # Safety
@@ -580,21 +589,21 @@ impl Stmt {
}
/// Returns the pointer to the statement.
pub fn to_ptr(&self) -> *const Stmt {
self
pub fn to_ptr(&self) -> *mut Stmt {
self as *const Stmt as *mut Stmt
}
/// Bind a value to a parameter in the prepared statement
/// Own the value so it can be freed in core
fn bind_args(&self, idx: NonZeroUsize, arg: Value) {
unsafe {
(self._bind_args_fn)(self.to_ptr() as *mut Stmt, idx.get() as i32, arg);
(self._bind_args_fn)(self.to_ptr(), idx.get() as i32, arg);
};
}
/// Execute the statement to attempt to retrieve the next result row.
fn step(&self) -> StepResult {
unsafe { (self._step)(self.to_ptr() as *mut Stmt) }.into()
unsafe { (self._step)(self.to_ptr()) }.into()
}
/// Free the memory for the values obtained from the `get_row` method.
@@ -608,7 +617,7 @@ impl Stmt {
return;
}
// free from the core side so we don't have to expose `__free_internal_type`
(self._free_current_row)(self.to_ptr() as *mut Stmt);
(self._free_current_row)(self.to_ptr());
self.current_row = std::ptr::null_mut();
self.current_row_len = -1;
}
@@ -616,7 +625,7 @@ impl Stmt {
/// Returns the values from the current row in the prepared statement, should
/// be called after the step() method returns `StepResult::Row`
pub fn get_row(&self) -> &[Value] {
unsafe { (self._get_row_values)(self.to_ptr() as *mut Stmt) };
unsafe { (self._get_row_values)(self.to_ptr()) };
if self.current_row.is_null() || self.current_row_len < 1 {
return &[];
}
@@ -628,7 +637,7 @@ impl Stmt {
pub fn get_column_names(&self) -> Vec<String> {
let mut count_value: i32 = 0;
let count: *mut i32 = &mut count_value;
let col_names = unsafe { (self._get_column_names)(self.to_ptr() as *mut Stmt, count) };
let col_names = unsafe { (self._get_column_names)(self.to_ptr(), count) };
if col_names.is_null() || count_value == 0 {
return Vec::new();
}

View File

@@ -370,6 +370,7 @@ impl VTabCursor for StatsCursor {
}
tables.push(tbl);
}
master.close();
for tbl in tables {
// count rows for each table
if let Ok(mut count_stmt) = conn.prepare(&format!("SELECT COUNT(*) FROM {};", tbl)) {
@@ -378,6 +379,7 @@ impl VTabCursor for StatsCursor {
_ => 0,
};
self.rows.push((tbl, count));
count_stmt.close();
}
}
if conn
@@ -403,6 +405,7 @@ impl VTabCursor for StatsCursor {
assert_eq!(val.to_integer(), Some(42));
}
}
stmt.close();
ResultCode::OK
}