From 98e3bc0c0c62848ba03c18c241784a88bbef374b Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Thu, 27 Feb 2025 10:10:07 +0200 Subject: [PATCH] bindings/rust: Make library thread-safe --- bindings/rust/src/lib.rs | 47 ++++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 366700a9e..0e331caf6 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -6,12 +6,14 @@ pub use params::params_from_iter; use crate::params::*; use crate::value::*; use std::rc::Rc; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; #[derive(Debug, thiserror::Error)] pub enum Error { #[error("SQL conversion failure: `{0}`")] ToSqlConversionFailure(BoxError), + #[error("Mutex lock error: {0}")] + MutexError(String), } impl From for Error { @@ -51,17 +53,33 @@ pub struct Database { inner: Arc, } +unsafe impl Send for Database {} +unsafe impl Sync for Database {} + impl Database { pub fn connect(&self) -> Result { let conn = self.inner.connect(); - Ok(Connection { inner: conn }) + Ok(Connection { + inner: Arc::new(Mutex::new(conn)), + }) } } pub struct Connection { - inner: Rc, + inner: Arc>>, } +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } + } +} + +unsafe impl Send for Connection {} +unsafe impl Sync for Connection {} + impl Connection { pub async fn query(&self, sql: &str, params: impl IntoParams) -> Result { let mut stmt = self.prepare(sql).await?; @@ -74,17 +92,26 @@ impl Connection { } pub async fn prepare(&self, sql: &str) -> Result { - let stmt = self.inner.prepare(sql)?; + let conn = self + .inner + .lock() + .map_err(|e| Error::MutexError(e.to_string()))?; + + let stmt = conn.prepare(sql)?; + Ok(Statement { - _inner: Rc::new(stmt), + inner: Arc::new(Mutex::new(Rc::new(stmt))), }) } } pub struct Statement { - _inner: Rc, + inner: Arc>>, } +unsafe impl Send for Statement {} +unsafe impl Sync for Statement {} + impl Statement { pub async fn query(&mut self, params: impl IntoParams) -> Result { let _params = params.into_params()?; @@ -110,9 +137,12 @@ pub enum Params { pub struct Transaction {} pub struct Rows { - _inner: Rc, + inner: Arc>>, } +unsafe impl Send for Rows {} +unsafe impl Sync for Rows {} + impl Rows { pub async fn next(&mut self) -> Result> { todo!(); @@ -121,6 +151,9 @@ impl Rows { pub struct Row {} +unsafe impl Send for Row {} +unsafe impl Sync for Row {} + impl Row { pub fn get_value(&self, _index: usize) -> Result { todo!();