From bb68fbdd674da59b001e2dcad5a26b6464f098d3 Mon Sep 17 00:00:00 2001 From: Yirt Grek Date: Wed, 12 Mar 2025 00:37:30 -0700 Subject: [PATCH] bindings/rust: Fix bindings so example runs --- bindings/rust/src/lib.rs | 58 ++++++++++++++++++++++++++++++++++---- bindings/rust/src/value.rs | 12 ++++++++ core/lib.rs | 4 +++ 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index b65624d10..03da2149a 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -6,6 +6,7 @@ pub use value::Value; pub use params::params_from_iter; use crate::params::*; +use std::num::NonZero; use std::rc::Rc; use std::sync::{Arc, Mutex}; @@ -63,7 +64,7 @@ unsafe impl Sync for Database {} impl Database { pub fn connect(&self) -> Result { - let conn = self.inner.connect().unwrap(); + let conn = self.inner.connect()?; #[allow(clippy::arc_with_non_send_sync)] let connection = Connection { inner: Arc::new(Mutex::new(conn)), @@ -125,8 +126,14 @@ impl Statement { pub async fn query(&mut self, params: impl IntoParams) -> Result { let params = params.into_params()?; match params { - crate::params::Params::None => {} - _ => todo!(), + params::Params::None => (), + params::Params::Positional(values) => { + for (i, value) in values.into_iter().enumerate() { + let mut stmt = self.inner.lock().unwrap(); + stmt.bind_at(NonZero::new(i + 1).unwrap(), value.into()); + } + } + params::Params::Named(_items) => todo!(), } #[allow(clippy::arc_with_non_send_sync)] let rows = Rows { @@ -136,8 +143,42 @@ impl Statement { } pub async fn execute(&mut self, params: impl IntoParams) -> Result { - let _params = params.into_params()?; - todo!(); + let params = params.into_params()?; + match params { + params::Params::None => (), + params::Params::Positional(values) => { + for (i, value) in values.into_iter().enumerate() { + let mut stmt = self.inner.lock().unwrap(); + stmt.bind_at(NonZero::new(i + 1).unwrap(), value.into()); + } + } + params::Params::Named(_items) => todo!(), + } + loop { + let mut stmt = self.inner.lock().unwrap(); + match stmt.step() { + Ok(limbo_core::StepResult::Row) => { + // unexpected row during execution, error out. + return Ok(2); + } + Ok(limbo_core::StepResult::Done) => { + return Ok(0); + } + Ok(limbo_core::StepResult::IO) => { + let _ = stmt.run_once(); + //return Ok(1); + } + Ok(limbo_core::StepResult::Busy) => { + return Ok(4); + } + Ok(limbo_core::StepResult::Interrupt) => { + return Ok(3); + } + Err(err) => { + return Err(err.into()); + } + } + } } } @@ -191,7 +232,12 @@ impl Row { let value = &self.values[index]; match value { limbo_core::OwnedValue::Integer(i) => Ok(Value::Integer(*i)), - _ => todo!(), + limbo_core::OwnedValue::Null => Ok(Value::Null), + limbo_core::OwnedValue::Float(f) => Ok(Value::Real(*f)), + limbo_core::OwnedValue::Text(text) => Ok(Value::Text(text.to_string())), + limbo_core::OwnedValue::Blob(items) => Ok(Value::Blob(items.to_vec())), + limbo_core::OwnedValue::Agg(_agg_context) => todo!(), + limbo_core::OwnedValue::Record(_record) => todo!(), } } } diff --git a/bindings/rust/src/value.rs b/bindings/rust/src/value.rs index d5e4e393b..899eeb4e3 100644 --- a/bindings/rust/src/value.rs +++ b/bindings/rust/src/value.rs @@ -110,6 +110,18 @@ impl Value { } } +impl Into for Value { + fn into(self) -> limbo_core::OwnedValue { + match self { + Value::Null => limbo_core::OwnedValue::Null, + Value::Integer(n) => limbo_core::OwnedValue::Integer(n), + Value::Real(n) => limbo_core::OwnedValue::Float(n), + Value::Text(t) => limbo_core::OwnedValue::from_text(&t), + Value::Blob(items) => limbo_core::OwnedValue::from_blob(items), + } + } +} + impl From for Value { fn from(value: i8) -> Value { Value::Integer(value as i64) diff --git a/core/lib.rs b/core/lib.rs index 9632e1829..64c86f33a 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -542,6 +542,10 @@ impl Statement { .step(&mut self.state, self.mv_store.clone(), self.pager.clone()) } + pub fn run_once(&self) -> Result<()> { + self.pager.io.run_once() + } + pub fn num_columns(&self) -> usize { self.program.result_columns.len() }