Handle multiple statements via sqlite3_exec API

This commit is contained in:
PThorpe92
2025-10-22 15:02:24 -04:00
parent 8c6a6f0aa1
commit d0fd258ab5

View File

@@ -380,25 +380,212 @@ type exec_callback = Option<
pub unsafe extern "C" fn sqlite3_exec(
db: *mut sqlite3,
sql: *const ffi::c_char,
_callback: exec_callback,
_context: *mut ffi::c_void,
_err: *mut *mut ffi::c_char,
callback: exec_callback,
context: *mut ffi::c_void,
err: *mut *mut ffi::c_char,
) -> ffi::c_int {
if db.is_null() || sql.is_null() {
return SQLITE_MISUSE;
}
let db: &mut sqlite3 = &mut *db;
let db = db.inner.lock().unwrap();
let sql = CStr::from_ptr(sql);
let sql = match sql.to_str() {
let db_ref: &mut sqlite3 = &mut *db;
let sql_cstr = CStr::from_ptr(sql);
let sql_str = match sql_cstr.to_str() {
Ok(s) => s,
Err(_) => return SQLITE_MISUSE,
};
trace!("sqlite3_exec(sql={})", sql);
match db.conn.execute(sql) {
Ok(_) => SQLITE_OK,
Err(_) => SQLITE_ERROR,
trace!("sqlite3_exec(sql={})", sql_str);
let statements = split_sql_statements(sql_str);
for stmt_sql in statements {
let trimmed = stmt_sql.trim();
if trimmed.is_empty() {
continue;
}
// check if this is a DQL statement, because we will only allow if there is a callback
let is_dql = is_query_statement(trimmed);
if is_dql && callback.is_none() {
if !err.is_null() {
let err_msg =
CString::new("queries return results, use callback or sqlite3_prepare")
.unwrap();
*err = err_msg.into_raw();
}
return SQLITE_MISUSE;
}
// For DML/DDL, use normal execute path
if !is_dql {
let db_inner = db_ref.inner.lock().unwrap();
match db_inner.conn.execute(trimmed) {
Ok(_) => continue,
Err(e) => {
if !err.is_null() {
let err_msg = format!("SQL error: {e:?}");
*err = CString::new(err_msg).unwrap().into_raw();
}
return SQLITE_ERROR;
}
}
} else {
// Handle DQL with callback
let rc = execute_query_with_callback(db, trimmed, callback, context, err);
if rc != SQLITE_OK {
return rc;
}
}
}
SQLITE_OK
}
/// Detect if a SQL statement is DQL
fn is_query_statement(sql: &str) -> bool {
let sql_upper = sql.to_uppercase();
let first_token = sql_upper.split_whitespace().next().unwrap_or("");
matches!(
first_token,
"SELECT" | "VALUES" | "WITH" | "PRAGMA" | "EXPLAIN"
) || sql_upper.contains("RETURNING")
}
/// Execute a query statement with callback for each row
/// Only called when we know callback is Some
unsafe fn execute_query_with_callback(
db: *mut sqlite3,
sql: &str,
callback: exec_callback,
context: *mut ffi::c_void,
err: *mut *mut ffi::c_char,
) -> ffi::c_int {
let sql_cstring = match CString::new(sql) {
Ok(s) => s,
Err(_) => return SQLITE_MISUSE,
};
let mut stmt_ptr: *mut sqlite3_stmt = std::ptr::null_mut();
let rc = sqlite3_prepare_v2(
db,
sql_cstring.as_ptr(),
-1,
&mut stmt_ptr,
std::ptr::null_mut(),
);
if rc != SQLITE_OK {
if !err.is_null() {
let err_msg = format!("Prepare failed: {rc}");
*err = CString::new(err_msg).unwrap().into_raw();
}
return rc;
}
let stmt_ref = &*stmt_ptr;
let n_cols = stmt_ref.stmt.num_columns() as ffi::c_int;
let mut column_names: Vec<CString> = Vec::with_capacity(n_cols as usize);
for i in 0..n_cols {
let name = stmt_ref.stmt.get_column_name(i as usize);
column_names.push(CString::new(name.as_bytes()).unwrap());
}
loop {
let step_rc = sqlite3_step(stmt_ptr);
match step_rc {
SQLITE_ROW => {
// Safety: checked earlier
let callback = callback.unwrap();
let mut values: Vec<CString> = Vec::with_capacity(n_cols as usize);
let mut value_ptrs: Vec<*mut ffi::c_char> = Vec::with_capacity(n_cols as usize);
let mut col_ptrs: Vec<*mut ffi::c_char> = Vec::with_capacity(n_cols as usize);
for i in 0..n_cols {
let val = stmt_ref.stmt.row().unwrap().get_value(i as usize);
values.push(CString::new(val.to_string().as_bytes()).unwrap());
}
for value in &values {
value_ptrs.push(value.as_ptr() as *mut ffi::c_char);
}
for name in &column_names {
col_ptrs.push(name.as_ptr() as *mut ffi::c_char);
}
let cb_rc = callback(
context,
n_cols,
value_ptrs.as_mut_ptr(),
col_ptrs.as_mut_ptr(),
);
if cb_rc != 0 {
sqlite3_finalize(stmt_ptr);
return SQLITE_ABORT;
}
}
SQLITE_DONE => {
break;
}
_ => {
sqlite3_finalize(stmt_ptr);
if !err.is_null() {
let err_msg = format!("Step failed: {step_rc}");
*err = CString::new(err_msg).unwrap().into_raw();
}
return step_rc;
}
}
}
sqlite3_finalize(stmt_ptr)
}
/// Split SQL string into individual statements
/// Handles quoted strings properly and skips comments
fn split_sql_statements(sql: &str) -> Vec<&str> {
let mut statements = Vec::new();
let mut current_start = 0;
let mut in_single_quote = false;
let mut in_double_quote = false;
let bytes = sql.as_bytes();
let mut i = 0;
while i < bytes.len() {
match bytes[i] {
// Check for escaped quotes first
b'\'' if !in_double_quote => {
if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
i += 2;
continue;
}
in_single_quote = !in_single_quote;
}
b'"' if !in_single_quote => {
if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
i += 2;
continue;
}
in_double_quote = !in_double_quote;
}
b';' if !in_single_quote && !in_double_quote => {
// we found the statement boundary
statements.push(&sql[current_start..i]);
current_start = i + 1;
}
_ => {}
}
i += 1;
}
if current_start < sql.len() {
statements.push(&sql[current_start..]);
}
statements
}
#[no_mangle]