From 282222a39ff2d8bba6641148bf27980d83d480d2 Mon Sep 17 00:00:00 2001 From: RS2007 Date: Sat, 26 Jul 2025 17:34:10 +0530 Subject: [PATCH] feat: execute_batch working --- bindings/rust/src/lib.rs | 14 ++++++- bindings/rust/tests/integration_tests.rs | 42 +++++++++++++------ core/lib.rs | 52 ++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 15 deletions(-) diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index e52e6e6e8..442bc4ad4 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -219,8 +219,9 @@ impl Connection { } /// Execute a batch of SQL statements on the database. - pub async fn execute_batch(&self, _sql: &str, _params: impl IntoParams) -> Result { - unimplemented!(); + pub async fn execute_batch(&self, sql: &str) -> Result<()> { + self.prepare_execute_batch(sql).await?; + Ok(()) } /// Prepare a SQL statement for later execution. @@ -239,6 +240,15 @@ impl Connection { Ok(statement) } + async fn prepare_execute_batch(&self, sql: impl AsRef) -> Result<()> { + let conn = self + .inner + .lock() + .map_err(|e| Error::MutexError(e.to_string()))?; + conn.prepare_execute_batch(sql)?; + Ok(()) + } + /// Query a pragma. pub fn pragma_query(&self, pragma_name: &str, mut f: F) -> Result<()> where diff --git a/bindings/rust/tests/integration_tests.rs b/bindings/rust/tests/integration_tests.rs index 9067d6440..7e79ff34e 100644 --- a/bindings/rust/tests/integration_tests.rs +++ b/bindings/rust/tests/integration_tests.rs @@ -149,16 +149,16 @@ async fn test_rows_returned() { //--- A more complicated example of insert with a select join subquery ---// conn.execute( "CREATE TABLE authors ( id INTEGER PRIMARY KEY, name TEXT NOT NULL); - ", + ", (), ) .await .unwrap(); conn.execute( - "CREATE TABLE books ( id INTEGER PRIMARY KEY, author_id INTEGER NOT NULL REFERENCES authors(id), title TEXT NOT NULL); " - ,() - ).await.unwrap(); + "CREATE TABLE books ( id INTEGER PRIMARY KEY, author_id INTEGER NOT NULL REFERENCES authors(id), title TEXT NOT NULL); " + ,() + ).await.unwrap(); conn.execute( "CREATE TABLE prize_winners ( book_id INTEGER PRIMARY KEY, author_name TEXT NOT NULL);", @@ -175,19 +175,19 @@ async fn test_rows_returned() { .unwrap(); conn.execute( - "INSERT INTO books (id, author_id, title) VALUES (1, 1, 'Rust in Action'), (2, 1, 'Async Adventures'), (3, 1, 'Fearless Concurrency'), (4, 1, 'Unsafe Tales'), (5, 1, 'Zero-Cost Futures'), (6, 2, 'Learning SQL');", - () - ).await.unwrap(); + "INSERT INTO books (id, author_id, title) VALUES (1, 1, 'Rust in Action'), (2, 1, 'Async Adventures'), (3, 1, 'Fearless Concurrency'), (4, 1, 'Unsafe Tales'), (5, 1, 'Zero-Cost Futures'), (6, 2, 'Learning SQL');", + () + ).await.unwrap(); let rows_changed = conn .execute( " - INSERT INTO prize_winners (book_id, author_name) - SELECT b.id, a.name - FROM books b - JOIN authors a ON a.id = b.author_id - WHERE a.id = 1; -- Alice’s five books - ", + INSERT INTO prize_winners (book_id, author_name) + SELECT b.id, a.name + FROM books b + JOIN authors a ON a.id = b.author_id + WHERE a.id = 1; -- Alice’s five books + ", (), ) .await @@ -195,3 +195,19 @@ async fn test_rows_returned() { assert_eq!(rows_changed, 5); } + +#[tokio::test] +pub async fn test_execute_batch() { + let db = Builder::new_local(":memory:").build().await.unwrap(); + let conn = db.connect().unwrap(); + conn.execute_batch("CREATE TABLE authors ( id INTEGER PRIMARY KEY, name TEXT NOT NULL);CREATE TABLE books ( id INTEGER PRIMARY KEY, author_id INTEGER NOT NULL REFERENCES authors(id), title TEXT NOT NULL); INSERT INTO authors (id, name) VALUES (1, 'Alice'), (2, 'Bob');") + .await + .unwrap(); + let mut rows = conn + .query("SELECT COUNT(*) FROM authors;", ()) + .await + .unwrap(); + if let Some(row) = rows.next().await.unwrap() { + assert_eq!(row.get_value(0).unwrap(), Value::Integer(2)); + } +} diff --git a/core/lib.rs b/core/lib.rs index e3cdea746..59a90a8a0 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -874,6 +874,58 @@ impl Connection { Ok(()) } + #[instrument(skip_all, level = Level::INFO)] + pub fn prepare(self: &Arc, sql: impl AsRef) -> Result { + pub fn prepare_execute_batch(self: &Arc, sql: impl AsRef) -> Result<()> { + if self.closed.get() { + return Err(LimboError::InternalError("Connection closed".to_string())); + } + if sql.as_ref().is_empty() { + return Err(LimboError::InvalidArgument( + "The supplied SQL string contains no statements".to_string(), + )); + } + let sql = sql.as_ref(); + tracing::trace!("Preparing and executing batch: {}", sql); + let mut parser = Parser::new(sql.as_bytes()); + while let Some(cmd) = parser.next()? { + dbg!(&cmd); + let syms = self.syms.borrow(); + let pager = self.pager.borrow().clone(); + let byte_offset_end = parser.offset(); + let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end]) + .unwrap() + .trim(); + dbg!(&self.schema); + match cmd { + Cmd::Stmt(stmt) => { + let program = translate::translate( + self.schema.borrow().deref(), + stmt, + pager.clone(), + self.clone(), + &syms, + QueryMode::Normal, + input, + )?; + + let mut state = + vdbe::ProgramState::new(program.max_registers, program.cursor_ref.len()); + loop { + let res = + program.step(&mut state, self._db.mv_store.clone(), pager.clone())?; + if matches!(res, StepResult::Done) { + break; + } + self.run_once()?; + } + } + _ => todo!(), + } + } + Ok(()) + } + #[instrument(skip_all, level = Level::INFO)] pub fn query(self: &Arc, sql: impl AsRef) -> Result> { if self.closed.get() {