diff --git a/core/ext/vtab_xconnect.rs b/core/ext/vtab_xconnect.rs index e116d7535..cae4485e5 100644 --- a/core/ext/vtab_xconnect.rs +++ b/core/ext/vtab_xconnect.rs @@ -89,7 +89,7 @@ pub unsafe extern "C" fn execute( /// Wraps core Connection::prepare with a custom Stmt object with the necessary function pointers. /// This object is boxed/leaked and the caller is responsible for freeing the memory. -pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) -> *const Stmt { +pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) -> *mut Stmt { let c_str = unsafe { CStr::from_ptr(sql as *mut c_char) }; let sql_str = match c_str.to_str() { Ok(s) => s.to_string(), @@ -120,7 +120,7 @@ pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) -> stmt_get_column_names, stmt_free_current_row, stmt_close, - ))) as *const Stmt + ))) } Err(e) => { tracing::error!("prepare_stmt: failed to prepare statement: {:?}", e); diff --git a/extensions/core/src/vtabs.rs b/extensions/core/src/vtabs.rs index 6b0ae0fc5..755621f56 100644 --- a/extensions/core/src/vtabs.rs +++ b/extensions/core/src/vtabs.rs @@ -343,7 +343,7 @@ impl ConstraintInfo { } } -pub type PrepareStmtFn = unsafe extern "C" fn(api: *mut Conn, sql: *const c_char) -> *const Stmt; +pub type PrepareStmtFn = unsafe extern "C" fn(api: *mut Conn, sql: *const c_char) -> *mut Stmt; pub type ExecuteFn = unsafe extern "C" fn( ctx: *mut Conn, sql: *const c_char, @@ -426,9 +426,9 @@ impl Conn { Err(ResultCode::Error) } - pub fn prepare_stmt(&self, sql: &str) -> *const Stmt { + pub fn prepare_stmt(&self, sql: &str) -> *mut Stmt { let Ok(sql) = CString::new(sql) else { - return std::ptr::null(); + return std::ptr::null_mut(); }; unsafe { (self._prepare_stmt)(self as *const _ as *mut Conn, sql.as_ptr()) } } @@ -438,8 +438,16 @@ impl Conn { /// Statements can be manually closed. #[derive(Debug)] #[repr(transparent)] -pub struct Statement(*const Stmt); +pub struct Statement(*mut Stmt); +impl Drop for Statement { + fn drop(&mut self) { + if self.0.is_null() { + return; + } + unsafe { (*self.0).close() } + } +} /// Public API for methods to allow extensions to query other tables for /// the connection that opened the VTable. #[derive(Debug)] @@ -562,12 +570,13 @@ impl Stmt { } /// Close the statement - pub fn close(&self) { + pub fn close(&mut self) { // null check to prevent double free if self._ctx.is_null() { return; } unsafe { (self._close)(self as *const Stmt as *mut Stmt) }; + self._ctx = std::ptr::null_mut(); } /// # Safety @@ -580,21 +589,21 @@ impl Stmt { } /// Returns the pointer to the statement. - pub fn to_ptr(&self) -> *const Stmt { - self + pub fn to_ptr(&self) -> *mut Stmt { + self as *const Stmt as *mut Stmt } /// Bind a value to a parameter in the prepared statement /// Own the value so it can be freed in core fn bind_args(&self, idx: NonZeroUsize, arg: Value) { unsafe { - (self._bind_args_fn)(self.to_ptr() as *mut Stmt, idx.get() as i32, arg); + (self._bind_args_fn)(self.to_ptr(), idx.get() as i32, arg); }; } /// Execute the statement to attempt to retrieve the next result row. fn step(&self) -> StepResult { - unsafe { (self._step)(self.to_ptr() as *mut Stmt) }.into() + unsafe { (self._step)(self.to_ptr()) }.into() } /// Free the memory for the values obtained from the `get_row` method. @@ -608,7 +617,7 @@ impl Stmt { return; } // free from the core side so we don't have to expose `__free_internal_type` - (self._free_current_row)(self.to_ptr() as *mut Stmt); + (self._free_current_row)(self.to_ptr()); self.current_row = std::ptr::null_mut(); self.current_row_len = -1; } @@ -616,7 +625,7 @@ impl Stmt { /// Returns the values from the current row in the prepared statement, should /// be called after the step() method returns `StepResult::Row` pub fn get_row(&self) -> &[Value] { - unsafe { (self._get_row_values)(self.to_ptr() as *mut Stmt) }; + unsafe { (self._get_row_values)(self.to_ptr()) }; if self.current_row.is_null() || self.current_row_len < 1 { return &[]; } @@ -628,7 +637,7 @@ impl Stmt { pub fn get_column_names(&self) -> Vec { let mut count_value: i32 = 0; let count: *mut i32 = &mut count_value; - let col_names = unsafe { (self._get_column_names)(self.to_ptr() as *mut Stmt, count) }; + let col_names = unsafe { (self._get_column_names)(self.to_ptr(), count) }; if col_names.is_null() || count_value == 0 { return Vec::new(); } diff --git a/extensions/tests/src/lib.rs b/extensions/tests/src/lib.rs index 5fe19db75..9c8ad0019 100644 --- a/extensions/tests/src/lib.rs +++ b/extensions/tests/src/lib.rs @@ -370,6 +370,7 @@ impl VTabCursor for StatsCursor { } tables.push(tbl); } + master.close(); for tbl in tables { // count rows for each table if let Ok(mut count_stmt) = conn.prepare(&format!("SELECT COUNT(*) FROM {};", tbl)) { @@ -378,6 +379,7 @@ impl VTabCursor for StatsCursor { _ => 0, }; self.rows.push((tbl, count)); + count_stmt.close(); } } if conn @@ -403,6 +405,7 @@ impl VTabCursor for StatsCursor { assert_eq!(val.to_integer(), Some(42)); } } + stmt.close(); ResultCode::OK }