mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-07 18:24:20 +01:00
Handle multiple statements via sqlite3_exec API
This commit is contained in:
@@ -380,25 +380,212 @@ type exec_callback = Option<
|
||||
pub unsafe extern "C" fn sqlite3_exec(
|
||||
db: *mut sqlite3,
|
||||
sql: *const ffi::c_char,
|
||||
_callback: exec_callback,
|
||||
_context: *mut ffi::c_void,
|
||||
_err: *mut *mut ffi::c_char,
|
||||
callback: exec_callback,
|
||||
context: *mut ffi::c_void,
|
||||
err: *mut *mut ffi::c_char,
|
||||
) -> ffi::c_int {
|
||||
if db.is_null() || sql.is_null() {
|
||||
return SQLITE_MISUSE;
|
||||
}
|
||||
let db: &mut sqlite3 = &mut *db;
|
||||
let db = db.inner.lock().unwrap();
|
||||
let sql = CStr::from_ptr(sql);
|
||||
let sql = match sql.to_str() {
|
||||
|
||||
let db_ref: &mut sqlite3 = &mut *db;
|
||||
let sql_cstr = CStr::from_ptr(sql);
|
||||
let sql_str = match sql_cstr.to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => return SQLITE_MISUSE,
|
||||
};
|
||||
trace!("sqlite3_exec(sql={})", sql);
|
||||
match db.conn.execute(sql) {
|
||||
Ok(_) => SQLITE_OK,
|
||||
Err(_) => SQLITE_ERROR,
|
||||
trace!("sqlite3_exec(sql={})", sql_str);
|
||||
let statements = split_sql_statements(sql_str);
|
||||
|
||||
for stmt_sql in statements {
|
||||
let trimmed = stmt_sql.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if this is a DQL statement, because we will only allow if there is a callback
|
||||
let is_dql = is_query_statement(trimmed);
|
||||
if is_dql && callback.is_none() {
|
||||
if !err.is_null() {
|
||||
let err_msg =
|
||||
CString::new("queries return results, use callback or sqlite3_prepare")
|
||||
.unwrap();
|
||||
*err = err_msg.into_raw();
|
||||
}
|
||||
return SQLITE_MISUSE;
|
||||
}
|
||||
|
||||
// For DML/DDL, use normal execute path
|
||||
if !is_dql {
|
||||
let db_inner = db_ref.inner.lock().unwrap();
|
||||
match db_inner.conn.execute(trimmed) {
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
if !err.is_null() {
|
||||
let err_msg = format!("SQL error: {e:?}");
|
||||
*err = CString::new(err_msg).unwrap().into_raw();
|
||||
}
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle DQL with callback
|
||||
let rc = execute_query_with_callback(db, trimmed, callback, context, err);
|
||||
if rc != SQLITE_OK {
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SQLITE_OK
|
||||
}
|
||||
|
||||
/// Detect if a SQL statement is DQL
|
||||
fn is_query_statement(sql: &str) -> bool {
|
||||
let sql_upper = sql.to_uppercase();
|
||||
let first_token = sql_upper.split_whitespace().next().unwrap_or("");
|
||||
|
||||
matches!(
|
||||
first_token,
|
||||
"SELECT" | "VALUES" | "WITH" | "PRAGMA" | "EXPLAIN"
|
||||
) || sql_upper.contains("RETURNING")
|
||||
}
|
||||
|
||||
/// Execute a query statement with callback for each row
|
||||
/// Only called when we know callback is Some
|
||||
unsafe fn execute_query_with_callback(
|
||||
db: *mut sqlite3,
|
||||
sql: &str,
|
||||
callback: exec_callback,
|
||||
context: *mut ffi::c_void,
|
||||
err: *mut *mut ffi::c_char,
|
||||
) -> ffi::c_int {
|
||||
let sql_cstring = match CString::new(sql) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return SQLITE_MISUSE,
|
||||
};
|
||||
|
||||
let mut stmt_ptr: *mut sqlite3_stmt = std::ptr::null_mut();
|
||||
let rc = sqlite3_prepare_v2(
|
||||
db,
|
||||
sql_cstring.as_ptr(),
|
||||
-1,
|
||||
&mut stmt_ptr,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
|
||||
if rc != SQLITE_OK {
|
||||
if !err.is_null() {
|
||||
let err_msg = format!("Prepare failed: {rc}");
|
||||
*err = CString::new(err_msg).unwrap().into_raw();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
let stmt_ref = &*stmt_ptr;
|
||||
let n_cols = stmt_ref.stmt.num_columns() as ffi::c_int;
|
||||
let mut column_names: Vec<CString> = Vec::with_capacity(n_cols as usize);
|
||||
|
||||
for i in 0..n_cols {
|
||||
let name = stmt_ref.stmt.get_column_name(i as usize);
|
||||
column_names.push(CString::new(name.as_bytes()).unwrap());
|
||||
}
|
||||
|
||||
loop {
|
||||
let step_rc = sqlite3_step(stmt_ptr);
|
||||
|
||||
match step_rc {
|
||||
SQLITE_ROW => {
|
||||
// Safety: checked earlier
|
||||
let callback = callback.unwrap();
|
||||
|
||||
let mut values: Vec<CString> = Vec::with_capacity(n_cols as usize);
|
||||
let mut value_ptrs: Vec<*mut ffi::c_char> = Vec::with_capacity(n_cols as usize);
|
||||
let mut col_ptrs: Vec<*mut ffi::c_char> = Vec::with_capacity(n_cols as usize);
|
||||
|
||||
for i in 0..n_cols {
|
||||
let val = stmt_ref.stmt.row().unwrap().get_value(i as usize);
|
||||
values.push(CString::new(val.to_string().as_bytes()).unwrap());
|
||||
}
|
||||
|
||||
for value in &values {
|
||||
value_ptrs.push(value.as_ptr() as *mut ffi::c_char);
|
||||
}
|
||||
for name in &column_names {
|
||||
col_ptrs.push(name.as_ptr() as *mut ffi::c_char);
|
||||
}
|
||||
|
||||
let cb_rc = callback(
|
||||
context,
|
||||
n_cols,
|
||||
value_ptrs.as_mut_ptr(),
|
||||
col_ptrs.as_mut_ptr(),
|
||||
);
|
||||
|
||||
if cb_rc != 0 {
|
||||
sqlite3_finalize(stmt_ptr);
|
||||
return SQLITE_ABORT;
|
||||
}
|
||||
}
|
||||
SQLITE_DONE => {
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
sqlite3_finalize(stmt_ptr);
|
||||
if !err.is_null() {
|
||||
let err_msg = format!("Step failed: {step_rc}");
|
||||
*err = CString::new(err_msg).unwrap().into_raw();
|
||||
}
|
||||
return step_rc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_finalize(stmt_ptr)
|
||||
}
|
||||
|
||||
/// Split SQL string into individual statements
|
||||
/// Handles quoted strings properly and skips comments
|
||||
fn split_sql_statements(sql: &str) -> Vec<&str> {
|
||||
let mut statements = Vec::new();
|
||||
let mut current_start = 0;
|
||||
let mut in_single_quote = false;
|
||||
let mut in_double_quote = false;
|
||||
let bytes = sql.as_bytes();
|
||||
let mut i = 0;
|
||||
|
||||
while i < bytes.len() {
|
||||
match bytes[i] {
|
||||
// Check for escaped quotes first
|
||||
b'\'' if !in_double_quote => {
|
||||
if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
in_single_quote = !in_single_quote;
|
||||
}
|
||||
b'"' if !in_single_quote => {
|
||||
if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
in_double_quote = !in_double_quote;
|
||||
}
|
||||
b';' if !in_single_quote && !in_double_quote => {
|
||||
// we found the statement boundary
|
||||
statements.push(&sql[current_start..i]);
|
||||
current_start = i + 1;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if current_start < sql.len() {
|
||||
statements.push(&sql[current_start..]);
|
||||
}
|
||||
|
||||
statements
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
||||
Reference in New Issue
Block a user