fix exec to run over multiple statements in the string

This commit is contained in:
Nikita Sivukhin
2025-09-25 12:03:52 +04:00
parent ddfa77997d
commit a938bdcf09
9 changed files with 169 additions and 57 deletions

View File

@@ -167,16 +167,30 @@ class Database {
} }
/** /**
* Executes a SQL statement. * Executes the given SQL string
* Unlike prepared statements, this can execute strings that contain multiple SQL statements
* *
* @param {string} sql - The SQL statement string to execute. * @param {string} sql - The string containing SQL statements to execute
*/ */
exec(sql) { exec(sql) {
let stmt = this.prepare(sql); const exec = this.db.executor(sql);
try { try {
stmt.run(); while (true) {
const stepResult = exec.stepSync();
if (stepResult === STEP_IO) {
this.db.ioLoopSync();
continue;
}
if (stepResult === STEP_DONE) {
break;
}
if (stepResult === STEP_ROW) {
// For exec(), we don't need the row data, just continue
continue;
}
}
} finally { } finally {
stmt.close(); exec.reset();
} }
} }

View File

@@ -165,16 +165,32 @@ class Database {
} }
/** /**
* Executes a SQL statement. * Executes the given SQL string
* Unlike prepared statements, this can execute strings that contain multiple SQL statements
* *
* @param {string} sql - The SQL statement string to execute. * @param {string} sql - The string containing SQL statements to execute
*/ */
async exec(sql) { async exec(sql) {
const stmt = this.prepare(sql); await this.execLock.acquire();
const exec = this.db.executor(sql);
try { try {
await stmt.run(); while (true) {
const stepResult = exec.stepSync();
if (stepResult === STEP_IO) {
await this.db.ioLoopAsync();
continue;
}
if (stepResult === STEP_DONE) {
break;
}
if (stepResult === STEP_ROW) {
// For exec(), we don't need the row data, just continue
continue;
}
}
} finally { } finally {
await stmt.close(); exec.reset();
this.execLock.release();
} }
} }

View File

@@ -19,6 +19,7 @@ export interface NativeDatabase {
ioLoopAsync(): Promise<void>; ioLoopAsync(): Promise<void>;
prepare(sql: string): NativeStatement; prepare(sql: string): NativeStatement;
executor(sql: string): NativeExecutor;
defaultSafeIntegers(toggle: boolean); defaultSafeIntegers(toggle: boolean);
totalChanges(): number; totalChanges(): number;
@@ -38,6 +39,10 @@ export interface TableColumn {
type: string type: string
} }
export interface NativeExecutor {
stepSync(): number;
reset();
}
export interface NativeStatement { export interface NativeStatement {
stepAsync(): Promise<number>; stepAsync(): Promise<number>;
stepSync(): number; stepSync(): number;

View File

@@ -11,6 +11,14 @@ test('in-memory db', () => {
expect(rows).toEqual([{ x: 1 }, { x: 3 }]); expect(rows).toEqual([{ x: 1 }, { x: 3 }]);
}) })
test('exec multiple statements', async () => {
const db = new Database(":memory:");
db.exec("CREATE TABLE t(x); INSERT INTO t VALUES (1); INSERT INTO t VALUES (2)");
const stmt = db.prepare("SELECT * FROM t");
const rows = stmt.all();
expect(rows).toEqual([{ x: 1 }, { x: 2 }]);
})
test('readonly-db', () => { test('readonly-db', () => {
const path = `test-${(Math.random() * 10000) | 0}.db`; const path = `test-${(Math.random() * 10000) | 0}.db`;
try { try {

View File

@@ -1,5 +1,10 @@
/* auto-generated by NAPI-RS */ /* auto-generated by NAPI-RS */
/* eslint-disable */ /* eslint-disable */
export declare class BatchExecutor {
stepSync(): number
reset(): void
}
/** A database connection. */ /** A database connection. */
export declare class Database { export declare class Database {
/** /**
@@ -39,6 +44,7 @@ export declare class Database {
* A `Statement` instance. * A `Statement` instance.
*/ */
prepare(sql: string): Statement prepare(sql: string): Statement
executor(sql: string): BatchExecutor
/** /**
* Returns the rowid of the last row inserted. * Returns the rowid of the last row inserted.
* *

View File

@@ -508,6 +508,7 @@ if (!nativeBinding) {
throw new Error(`Failed to load native binding`) throw new Error(`Failed to load native binding`)
} }
const { Database, Statement } = nativeBinding const { BatchExecutor, Database, Statement } = nativeBinding
export { BatchExecutor }
export { Database } export { Database }
export { Statement } export { Statement }

View File

@@ -31,6 +31,14 @@ test('in-memory-db-async', async () => {
expect(rows).toEqual([{ x: 1 }, { x: 3 }]); expect(rows).toEqual([{ x: 1 }, { x: 3 }]);
}) })
test('exec multiple statements', async () => {
const db = await connect(":memory:");
await db.exec("CREATE TABLE t(x); INSERT INTO t VALUES (1); INSERT INTO t VALUES (2)");
const stmt = db.prepare("SELECT * FROM t");
const rows = await stmt.all();
expect(rows).toEqual([{ x: 1 }, { x: 2 }]);
})
test('readonly-db', async () => { test('readonly-db', async () => {
const path = `test-${(Math.random() * 10000) | 0}.db`; const path = `test-${(Math.random() * 10000) | 0}.db`;
try { try {

View File

@@ -121,12 +121,8 @@ pub struct DatabaseOpts {
pub tracing: Option<String>, pub tracing: Option<String>,
} }
fn step_sync(stmt: &Arc<RefCell<Option<turso_core::Statement>>>) -> napi::Result<u32> { fn step_sync(stmt: &Arc<RefCell<turso_core::Statement>>) -> napi::Result<u32> {
let mut stmt_ref = stmt.borrow_mut(); let mut stmt = stmt.borrow_mut();
let stmt = stmt_ref
.as_mut()
.ok_or_else(|| create_generic_error("statement has been finalized"))?;
match stmt.step() { match stmt.step() {
Ok(turso_core::StepResult::Row) => Ok(STEP_ROW), Ok(turso_core::StepResult::Row) => Ok(STEP_ROW),
Ok(turso_core::StepResult::IO) => Ok(STEP_IO), Ok(turso_core::StepResult::IO) => Ok(STEP_IO),
@@ -357,13 +353,23 @@ impl Database {
.collect(); .collect();
Ok(Statement { Ok(Statement {
#[allow(clippy::arc_with_non_send_sync)] #[allow(clippy::arc_with_non_send_sync)]
stmt: Arc::new(RefCell::new(Some(stmt))), stmt: Some(Arc::new(RefCell::new(stmt))),
column_names, column_names,
mode: RefCell::new(PresentationMode::Expanded), mode: RefCell::new(PresentationMode::Expanded),
safe_integers: Cell::new(*self.inner()?.default_safe_integers.lock().unwrap()), safe_integers: Cell::new(*self.inner()?.default_safe_integers.lock().unwrap()),
}) })
} }
#[napi]
pub fn executor(&self, sql: String) -> napi::Result<BatchExecutor> {
return Ok(BatchExecutor {
conn: Some(self.conn()?.clone()),
sql,
position: 0,
stmt: None,
});
}
/// Returns the rowid of the last row inserted. /// Returns the rowid of the last row inserted.
/// ///
/// # Returns /// # Returns
@@ -432,10 +438,55 @@ impl Database {
} }
} }
#[napi]
pub struct BatchExecutor {
conn: Option<Arc<turso_core::Connection>>,
sql: String,
position: usize,
stmt: Option<Arc<RefCell<turso_core::Statement>>>,
}
#[napi]
impl BatchExecutor {
#[napi]
pub fn step_sync(&mut self) -> Result<u32> {
loop {
if self.stmt.is_none() && self.position >= self.sql.len() {
return Ok(STEP_DONE);
}
if self.stmt.is_none() {
let conn = self.conn.as_ref().unwrap();
match conn.consume_stmt(&self.sql[self.position..]) {
Ok(Some((stmt, offset))) => {
self.position += offset;
self.stmt = Some(Arc::new(RefCell::new(stmt)));
}
Ok(None) => return Ok(STEP_DONE),
Err(err) => return Err(to_generic_error("failed to consume stmt", err)),
}
}
let stmt = self.stmt.as_ref().unwrap();
match step_sync(stmt) {
Ok(STEP_DONE) => {
let _ = self.stmt.take();
continue;
}
result => return result,
}
}
}
#[napi]
pub fn reset(&mut self) {
let _ = self.conn.take();
let _ = self.stmt.take();
}
}
/// A prepared statement. /// A prepared statement.
#[napi] #[napi]
pub struct Statement { pub struct Statement {
stmt: Arc<RefCell<Option<turso_core::Statement>>>, stmt: Option<Arc<RefCell<turso_core::Statement>>>,
column_names: Vec<std::ffi::CString>, column_names: Vec<std::ffi::CString>,
mode: RefCell<PresentationMode>, mode: RefCell<PresentationMode>,
safe_integers: Cell<bool>, safe_integers: Cell<bool>,
@@ -443,24 +494,21 @@ pub struct Statement {
#[napi] #[napi]
impl Statement { impl Statement {
pub fn stmt(&self) -> napi::Result<&Arc<RefCell<turso_core::Statement>>> {
self.stmt
.as_ref()
.ok_or_else(|| create_generic_error("statement has been finalized"))
}
#[napi] #[napi]
pub fn reset(&self) -> Result<()> { pub fn reset(&self) -> Result<()> {
let mut stmt = self.stmt.borrow_mut(); self.stmt()?.borrow_mut().reset();
let stmt = stmt
.as_mut()
.ok_or_else(|| create_generic_error("statement has been finalized"))?;
stmt.reset();
Ok(()) Ok(())
} }
/// Returns the number of parameters in the statement. /// Returns the number of parameters in the statement.
#[napi] #[napi]
pub fn parameter_count(&self) -> Result<u32> { pub fn parameter_count(&self) -> Result<u32> {
let stmt = self.stmt.borrow(); Ok(self.stmt()?.borrow().parameters_count() as u32)
let stmt = stmt
.as_ref()
.ok_or_else(|| create_generic_error("statement has been finalized"))?;
Ok(stmt.parameters_count() as u32)
} }
/// Returns the name of a parameter at a specific 1-based index. /// Returns the name of a parameter at a specific 1-based index.
@@ -470,15 +518,11 @@ impl Statement {
/// * `index` - The 1-based parameter index. /// * `index` - The 1-based parameter index.
#[napi] #[napi]
pub fn parameter_name(&self, index: u32) -> Result<Option<String>> { pub fn parameter_name(&self, index: u32) -> Result<Option<String>> {
let stmt = self.stmt.borrow();
let stmt = stmt
.as_ref()
.ok_or_else(|| create_generic_error("statement has been finalized"))?;
let non_zero_idx = NonZeroUsize::new(index as usize).ok_or_else(|| { let non_zero_idx = NonZeroUsize::new(index as usize).ok_or_else(|| {
create_error(Status::InvalidArg, "parameter index must be greater than 0") create_error(Status::InvalidArg, "parameter index must be greater than 0")
})?; })?;
let stmt = self.stmt()?.borrow();
Ok(stmt.parameters().name(non_zero_idx).map(|s| s.to_string())) Ok(stmt.parameters().name(non_zero_idx).map(|s| s.to_string()))
} }
@@ -491,11 +535,6 @@ impl Statement {
/// * `value` - The value to bind. /// * `value` - The value to bind.
#[napi] #[napi]
pub fn bind_at(&self, index: u32, value: Unknown) -> Result<()> { pub fn bind_at(&self, index: u32, value: Unknown) -> Result<()> {
let mut stmt = self.stmt.borrow_mut();
let stmt = stmt
.as_mut()
.ok_or_else(|| create_generic_error("statement has been finalized"))?;
let non_zero_idx = NonZeroUsize::new(index as usize).ok_or_else(|| { let non_zero_idx = NonZeroUsize::new(index as usize).ok_or_else(|| {
create_error(Status::InvalidArg, "parameter index must be greater than 0") create_error(Status::InvalidArg, "parameter index must be greater than 0")
})?; })?;
@@ -547,7 +586,7 @@ impl Statement {
} }
}; };
stmt.bind_at(non_zero_idx, turso_value); self.stmt()?.borrow_mut().bind_at(non_zero_idx, turso_value);
Ok(()) Ok(())
} }
@@ -555,17 +594,13 @@ impl Statement {
/// 1 = Row available, 2 = Done, 3 = I/O needed /// 1 = Row available, 2 = Done, 3 = I/O needed
#[napi] #[napi]
pub fn step_sync(&self) -> Result<u32> { pub fn step_sync(&self) -> Result<u32> {
step_sync(&self.stmt) step_sync(self.stmt()?)
} }
/// Get the current row data according to the presentation mode /// Get the current row data according to the presentation mode
#[napi] #[napi]
pub fn row<'env>(&self, env: &'env Env) -> Result<Unknown<'env>> { pub fn row<'env>(&self, env: &'env Env) -> Result<Unknown<'env>> {
let stmt_ref = self.stmt.borrow(); let stmt = self.stmt()?.borrow();
let stmt = stmt_ref
.as_ref()
.ok_or_else(|| create_generic_error("statement has been finalized"))?;
let row_data = stmt let row_data = stmt
.row() .row()
.ok_or_else(|| create_generic_error("no row data available"))?; .ok_or_else(|| create_generic_error("no row data available"))?;
@@ -647,10 +682,7 @@ impl Statement {
/// Get column information for the statement /// Get column information for the statement
#[napi(ts_return_type = "Promise<any>")] #[napi(ts_return_type = "Promise<any>")]
pub fn columns<'env>(&self, env: &'env Env) -> Result<Array<'env>> { pub fn columns<'env>(&self, env: &'env Env) -> Result<Array<'env>> {
let stmt_ref = self.stmt.borrow(); let stmt = self.stmt()?.borrow();
let stmt = stmt_ref
.as_ref()
.ok_or_else(|| create_generic_error("statement has been finalized"))?;
let column_count = stmt.num_columns(); let column_count = stmt.num_columns();
let mut js_array = env.create_array(column_count as u32)?; let mut js_array = env.create_array(column_count as u32)?;
@@ -682,13 +714,8 @@ impl Statement {
/// Finalizes the statement. /// Finalizes the statement.
#[napi] #[napi]
pub fn finalize(&self) -> Result<()> { pub fn finalize(&mut self) -> Result<()> {
match self.stmt.try_borrow_mut() { let _ = self.stmt.take();
Ok(mut stmt) => {
stmt.take();
}
Err(err) => tracing::error!("borrow error: {:?}", err),
}
Ok(()) Ok(())
} }
} }

View File

@@ -1316,6 +1316,33 @@ impl Connection {
Ok(()) Ok(())
} }
#[instrument(skip_all, level = Level::INFO)]
pub fn consume_stmt(self: &Arc<Connection>, sql: &str) -> Result<Option<(Statement, usize)>> {
let mut parser = Parser::new(sql.as_bytes());
let Some(cmd) = parser.next_cmd()? else {
return Ok(None);
};
let syms = self.syms.read();
let pager = self.pager.read().clone();
let byte_offset_end = parser.offset();
let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end])
.unwrap()
.trim();
let mode = QueryMode::new(&cmd);
let (Cmd::Stmt(stmt) | Cmd::Explain(stmt) | Cmd::ExplainQueryPlan(stmt)) = cmd;
let program = translate::translate(
self.schema.read().deref(),
stmt,
pager.clone(),
self.clone(),
&syms,
mode,
input,
)?;
let stmt = Statement::new(program, self.db.mv_store.clone(), pager.clone(), mode);
Ok(Some((stmt, parser.offset())))
}
#[cfg(feature = "fs")] #[cfg(feature = "fs")]
pub fn from_uri( pub fn from_uri(
uri: &str, uri: &str,