diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 0083715ae..32e20b0ca 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -380,25 +380,212 @@ type exec_callback = Option< pub unsafe extern "C" fn sqlite3_exec( db: *mut sqlite3, sql: *const ffi::c_char, - _callback: exec_callback, - _context: *mut ffi::c_void, - _err: *mut *mut ffi::c_char, + callback: exec_callback, + context: *mut ffi::c_void, + err: *mut *mut ffi::c_char, ) -> ffi::c_int { if db.is_null() || sql.is_null() { return SQLITE_MISUSE; } - let db: &mut sqlite3 = &mut *db; - let db = db.inner.lock().unwrap(); - let sql = CStr::from_ptr(sql); - let sql = match sql.to_str() { + + let db_ref: &mut sqlite3 = &mut *db; + let sql_cstr = CStr::from_ptr(sql); + let sql_str = match sql_cstr.to_str() { Ok(s) => s, Err(_) => return SQLITE_MISUSE, }; - trace!("sqlite3_exec(sql={})", sql); - match db.conn.execute(sql) { - Ok(_) => SQLITE_OK, - Err(_) => SQLITE_ERROR, + trace!("sqlite3_exec(sql={})", sql_str); + let statements = split_sql_statements(sql_str); + + for stmt_sql in statements { + let trimmed = stmt_sql.trim(); + if trimmed.is_empty() { + continue; + } + + // check if this is a DQL statement, because we will only allow if there is a callback + let is_dql = is_query_statement(trimmed); + if is_dql && callback.is_none() { + if !err.is_null() { + let err_msg = + CString::new("queries return results, use callback or sqlite3_prepare") + .unwrap(); + *err = err_msg.into_raw(); + } + return SQLITE_MISUSE; + } + + // For DML/DDL, use normal execute path + if !is_dql { + let db_inner = db_ref.inner.lock().unwrap(); + match db_inner.conn.execute(trimmed) { + Ok(_) => continue, + Err(e) => { + if !err.is_null() { + let err_msg = format!("SQL error: {e:?}"); + *err = CString::new(err_msg).unwrap().into_raw(); + } + return SQLITE_ERROR; + } + } + } else { + // Handle DQL with callback + let rc = execute_query_with_callback(db, trimmed, callback, context, err); + if rc != SQLITE_OK { + return rc; + } + } } + + SQLITE_OK +} + +/// Detect if a SQL statement is DQL +fn is_query_statement(sql: &str) -> bool { + let sql_upper = sql.to_uppercase(); + let first_token = sql_upper.split_whitespace().next().unwrap_or(""); + + matches!( + first_token, + "SELECT" | "VALUES" | "WITH" | "PRAGMA" | "EXPLAIN" + ) || sql_upper.contains("RETURNING") +} + +/// Execute a query statement with callback for each row +/// Only called when we know callback is Some +unsafe fn execute_query_with_callback( + db: *mut sqlite3, + sql: &str, + callback: exec_callback, + context: *mut ffi::c_void, + err: *mut *mut ffi::c_char, +) -> ffi::c_int { + let sql_cstring = match CString::new(sql) { + Ok(s) => s, + Err(_) => return SQLITE_MISUSE, + }; + + let mut stmt_ptr: *mut sqlite3_stmt = std::ptr::null_mut(); + let rc = sqlite3_prepare_v2( + db, + sql_cstring.as_ptr(), + -1, + &mut stmt_ptr, + std::ptr::null_mut(), + ); + + if rc != SQLITE_OK { + if !err.is_null() { + let err_msg = format!("Prepare failed: {rc}"); + *err = CString::new(err_msg).unwrap().into_raw(); + } + return rc; + } + + let stmt_ref = &*stmt_ptr; + let n_cols = stmt_ref.stmt.num_columns() as ffi::c_int; + let mut column_names: Vec = Vec::with_capacity(n_cols as usize); + + for i in 0..n_cols { + let name = stmt_ref.stmt.get_column_name(i as usize); + column_names.push(CString::new(name.as_bytes()).unwrap()); + } + + loop { + let step_rc = sqlite3_step(stmt_ptr); + + match step_rc { + SQLITE_ROW => { + // Safety: checked earlier + let callback = callback.unwrap(); + + let mut values: Vec = Vec::with_capacity(n_cols as usize); + let mut value_ptrs: Vec<*mut ffi::c_char> = Vec::with_capacity(n_cols as usize); + let mut col_ptrs: Vec<*mut ffi::c_char> = Vec::with_capacity(n_cols as usize); + + for i in 0..n_cols { + let val = stmt_ref.stmt.row().unwrap().get_value(i as usize); + values.push(CString::new(val.to_string().as_bytes()).unwrap()); + } + + for value in &values { + value_ptrs.push(value.as_ptr() as *mut ffi::c_char); + } + for name in &column_names { + col_ptrs.push(name.as_ptr() as *mut ffi::c_char); + } + + let cb_rc = callback( + context, + n_cols, + value_ptrs.as_mut_ptr(), + col_ptrs.as_mut_ptr(), + ); + + if cb_rc != 0 { + sqlite3_finalize(stmt_ptr); + return SQLITE_ABORT; + } + } + SQLITE_DONE => { + break; + } + _ => { + sqlite3_finalize(stmt_ptr); + if !err.is_null() { + let err_msg = format!("Step failed: {step_rc}"); + *err = CString::new(err_msg).unwrap().into_raw(); + } + return step_rc; + } + } + } + + sqlite3_finalize(stmt_ptr) +} + +/// Split SQL string into individual statements +/// Handles quoted strings properly and skips comments +fn split_sql_statements(sql: &str) -> Vec<&str> { + let mut statements = Vec::new(); + let mut current_start = 0; + let mut in_single_quote = false; + let mut in_double_quote = false; + let bytes = sql.as_bytes(); + let mut i = 0; + + while i < bytes.len() { + match bytes[i] { + // Check for escaped quotes first + b'\'' if !in_double_quote => { + if i + 1 < bytes.len() && bytes[i + 1] == b'\'' { + i += 2; + continue; + } + in_single_quote = !in_single_quote; + } + b'"' if !in_single_quote => { + if i + 1 < bytes.len() && bytes[i + 1] == b'"' { + i += 2; + continue; + } + in_double_quote = !in_double_quote; + } + b';' if !in_single_quote && !in_double_quote => { + // we found the statement boundary + statements.push(&sql[current_start..i]); + current_start = i + 1; + } + _ => {} + } + i += 1; + } + + if current_start < sql.len() { + statements.push(&sql[current_start..]); + } + + statements } #[no_mangle]