diff --git a/bindings/javascript/__test__/sync.spec.mjs b/bindings/javascript/__test__/sync.spec.mjs index b45fdba0b..4f1b4a965 100644 --- a/bindings/javascript/__test__/sync.spec.mjs +++ b/bindings/javascript/__test__/sync.spec.mjs @@ -84,7 +84,7 @@ dualTest.both("Statement.get() [no parameters]", async (t) => { t.deepEqual(stmt.raw().get(), [1, 'Alice', 'alice@example.org']); }); -dualTest.onlySqlitePasses("Statement.get() [positional]", async (t) => { +dualTest.both("Statement.get() [positional]", async (t) => { const db = t.context.db; var stmt = 0; @@ -101,7 +101,7 @@ dualTest.onlySqlitePasses("Statement.get() [positional]", async (t) => { t.is(stmt.get({ 1: 2 }).name, "Bob"); }); -dualTest.onlySqlitePasses("Statement.get() [named]", async (t) => { +dualTest.both("Statement.get() [named]", async (t) => { const db = t.context.db; var stmt = undefined; @@ -132,7 +132,7 @@ dualTest.both("Statement.get() [raw]", async (t) => { t.deepEqual(stmt.raw().get(1), [1, "Alice", "alice@example.org"]); }); -dualTest.onlySqlitePasses("Statement.iterate() [empty]", async (t) => { +dualTest.both("Statement.iterate() [empty]", async (t) => { const db = t.context.db; const stmt = db.prepare("SELECT * FROM users WHERE id = 0"); diff --git a/bindings/javascript/index.d.ts b/bindings/javascript/index.d.ts index a57984b28..392fee456 100644 --- a/bindings/javascript/index.d.ts +++ b/bindings/javascript/index.d.ts @@ -44,4 +44,6 @@ export declare class Statement { static columns(): void bind(args?: Array | undefined | null): Statement } -export declare class IteratorStatement { } +export declare class IteratorStatement { + [Symbol.iterator](): Iterator +} diff --git a/bindings/javascript/src/lib.rs b/bindings/javascript/src/lib.rs index 64e8ab7b4..45563dfce 100644 --- a/bindings/javascript/src/lib.rs +++ b/bindings/javascript/src/lib.rs @@ -505,10 +505,21 @@ impl Statement { return Err(napi::Error::from_status(napi::Status::PendingException)); } - if args.len() == 1 && matches!(args[0].get_type()?, napi::ValueType::Object) { - bind_named_parameters(&mut stmt, args)?; + if args.len() == 1 { + if matches!(args[0].get_type()?, napi::ValueType::Object) { + let obj: napi::JsObject = + args.into_iter().next().unwrap().coerce_to_object()?; + + if obj.is_array()? { + bind_positional_param_array(&mut stmt, &obj)?; + } else { + bind_host_params(&mut stmt, &obj)?; + } + } else { + bind_single_param(&mut stmt, args.into_iter().next().unwrap())?; + } } else { - bind_parameters(&mut stmt, args)?; + bind_positional_params(&mut stmt, args)?; } } @@ -516,7 +527,7 @@ impl Statement { } } -fn bind_parameters( +fn bind_positional_params( stmt: &mut RefMut<'_, turso_core::Statement>, args: Vec, ) -> Result<(), napi::Error> { @@ -527,22 +538,94 @@ fn bind_parameters( Ok(()) } -fn bind_named_parameters( +fn bind_host_params( stmt: &mut RefMut<'_, turso_core::Statement>, - args: Vec, + obj: &napi::JsObject, ) -> Result<(), napi::Error> { - let obj: napi::JsObject = args.into_iter().next().unwrap().coerce_to_object()?; - for idx in 1..stmt.parameters_count() { + if first_key_is_number(obj) { + bind_numbered_params(stmt, obj)?; + } else { + bind_named_params(stmt, obj)?; + } + + Ok(()) +} + +fn first_key_is_number(obj: &napi::JsObject) -> bool { + napi::JsObject::keys(obj) + .iter() + .flatten() + .filter(|key| matches!(obj.has_own_property(key), Ok(result) if result)) + .take(1) + .any(|key| str::parse::(key).is_ok()) +} + +fn bind_numbered_params( + stmt: &mut RefMut<'_, turso_core::Statement>, + obj: &napi::JsObject, +) -> Result<(), napi::Error> { + for key in napi::JsObject::keys(obj)?.iter() { + let Ok(param_idx) = str::parse::(key) else { + return Err(napi::Error::new( + napi::Status::GenericFailure, + "cannot mix numbers and strings", + )); + }; + let Some(non_zero) = NonZero::new(param_idx as usize) else { + return Err(napi::Error::new( + napi::Status::GenericFailure, + "numbered parameters cannot be lower than 1", + )); + }; + + stmt.bind_at(non_zero, from_js_value(obj.get_named_property(key)?)?); + } + Ok(()) +} + +fn bind_named_params( + stmt: &mut RefMut<'_, turso_core::Statement>, + obj: &napi::JsObject, +) -> Result<(), napi::Error> { + for idx in 1..stmt.parameters_count() + 1 { let non_zero_idx = NonZero::new(idx).unwrap(); let param = stmt.parameters().name(non_zero_idx); let Some(name) = param else { - return Err(napi::Error::from_status(napi::Status::GenericFailure)); + return Err(napi::Error::from_reason(format!( + "could not find named parameter with index {}", + idx + ))); }; - let value = obj.get_named_property::(&name)?; + let value = obj.get_named_property::(&name[1..])?; stmt.bind_at(non_zero_idx, from_js_value(value)?); } + + Ok(()) +} + +fn bind_positional_param_array( + stmt: &mut RefMut<'_, turso_core::Statement>, + obj: &napi::JsObject, +) -> Result<(), napi::Error> { + assert!(obj.is_array()?, "bind_array can only be called with arrays"); + + for idx in 1..obj.get_array_length()? { + stmt.bind_at( + NonZero::new(idx as usize).unwrap(), + from_js_value(obj.get_element(idx)?)?, + ); + } + + Ok(()) +} + +fn bind_single_param( + stmt: &mut RefMut<'_, turso_core::Statement>, + obj: napi::JsUnknown, +) -> Result<(), napi::Error> { + stmt.bind_at(NonZero::new(1).unwrap(), from_js_value(obj)?); Ok(()) } @@ -554,6 +637,7 @@ pub struct IteratorStatement { presentation_mode: PresentationMode, } +#[napi] impl Generator for IteratorStatement { type Yield = JsUnknown; diff --git a/bindings/javascript/wrapper.js b/bindings/javascript/wrapper.js index 57235d8e5..9d2000386 100644 --- a/bindings/javascript/wrapper.js +++ b/bindings/javascript/wrapper.js @@ -264,8 +264,11 @@ class Statement { * * @param bindParameters - The bind parameters for executing the statement. */ - iterate(...bindParameters) { - return this.stmt.iterate(bindParameters.flat()); + *iterate(...bindParameters) { + // revisit this solution when https://github.com/napi-rs/napi-rs/issues/2574 is fixed + for (const row of this.stmt.iterate(bindParameters.flat())) { + yield row; + } } /**