From 83303b8c5b2e8121f07061733e977d477554f131 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 16 Sep 2025 12:21:17 +0400 Subject: [PATCH] properly guard access to the sync engine with locks --- bindings/javascript/sync/src/lib.rs | 116 +++++++++++++++++----------- 1 file changed, 72 insertions(+), 44 deletions(-) diff --git a/bindings/javascript/sync/src/lib.rs b/bindings/javascript/sync/src/lib.rs index dd5e3f080..c70932081 100644 --- a/bindings/javascript/sync/src/lib.rs +++ b/bindings/javascript/sync/src/lib.rs @@ -6,7 +6,7 @@ pub mod js_protocol_io; use std::{ collections::HashMap, - sync::{Arc, Mutex, OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard}, + sync::{Arc, Mutex, MutexGuard, OnceLock, RwLock, RwLockReadGuard}, }; use napi::bindgen_prelude::{AsyncTask, Either5, Null}; @@ -28,17 +28,50 @@ pub struct DatabaseOpts { pub path: String, } +pub struct SyncEngineGuard { + inner: Arc>>>, + wait_lock: Mutex<()>, + push_lock: Mutex<()>, + pull_lock: Mutex<()>, + checkpoint_lock: Mutex<()>, +} + +impl SyncEngineGuard { + fn checkpoint_lock(&self) -> (MutexGuard<'_, ()>, MutexGuard<'_, ()>, MutexGuard<'_, ()>) { + let push = self.push_lock.lock().unwrap(); + let pull = self.pull_lock.lock().unwrap(); + let checkpoint = self.checkpoint_lock.lock().unwrap(); + (push, pull, checkpoint) + } + fn pull_lock(&self) -> (MutexGuard<'_, ()>, MutexGuard<'_, ()>, MutexGuard<'_, ()>) { + let wait = self.wait_lock.lock().unwrap(); + let push = self.push_lock.lock().unwrap(); + let pull = self.pull_lock.lock().unwrap(); + (wait, push, pull) + } + fn push_lock(&self) -> MutexGuard<'_, ()> { + let push = self.push_lock.lock().unwrap(); + push + } + fn wait_lock(&self) -> (MutexGuard<'_, ()>, MutexGuard<'_, ()>) { + let wait = self.wait_lock.lock().unwrap(); + let pull = self.pull_lock.lock().unwrap(); + (wait, pull) + } +} + #[napi] pub struct SyncEngine { path: String, client_name: String, wal_pull_batch_size: u32, + long_poll_timeout: Option, protocol_version: DatabaseSyncEngineProtocolVersion, tables_ignore: Vec, use_transform: bool, io: Option>, protocol: Option>, - sync_engine: Arc>>>, + sync_engine: Arc, opened: Arc>>, } @@ -123,6 +156,7 @@ pub struct SyncEngineOpts { pub path: String, pub client_name: Option, pub wal_pull_batch_size: Option, + pub long_poll_timeout_ms: Option, pub tracing: Option, pub tables_ignore: Option>, pub use_transform: bool, @@ -174,10 +208,19 @@ impl SyncEngine { path: opts.path, client_name: opts.client_name.unwrap_or("turso-sync-js".to_string()), wal_pull_batch_size: opts.wal_pull_batch_size.unwrap_or(100), + long_poll_timeout: opts + .long_poll_timeout_ms + .map(|x| std::time::Duration::from_millis(x as u64)), tables_ignore: opts.tables_ignore.unwrap_or_default(), use_transform: opts.use_transform, #[allow(clippy::arc_with_non_send_sync)] - sync_engine: Arc::new(RwLock::new(None)), + sync_engine: Arc::new(SyncEngineGuard { + inner: Arc::new(RwLock::new(None)), + wait_lock: Mutex::new(()), + push_lock: Mutex::new(()), + pull_lock: Mutex::new(()), + checkpoint_lock: Mutex::new(()), + }), io: Some(io), protocol: Some(Arc::new(JsProtocolIo::default())), #[allow(clippy::arc_with_non_send_sync)] @@ -196,6 +239,7 @@ impl SyncEngine { let opts = DatabaseSyncEngineOpts { client_name: self.client_name.clone(), wal_pull_batch_size: self.wal_pull_batch_size as u64, + long_poll_timeout: self.long_poll_timeout, tables_ignore: self.tables_ignore.clone(), use_transform: self.use_transform, protocol_version_hint: self.protocol_version, @@ -213,7 +257,7 @@ impl SyncEngine { let connection = initialized.connect_rw(&coro).await?; let db = turso_node::Database::create(None, io.clone(), connection, path); - *sync_engine.write().unwrap() = Some(initialized); + *sync_engine.inner.write().unwrap() = Some(initialized); *opened.lock().unwrap() = Some(db); Ok(()) }); @@ -246,9 +290,10 @@ impl SyncEngine { #[napi] pub fn sync(&self) -> GeneratorHolder { - self.run(async move |coro, sync_engine| { - let mut sync_engine = try_write(sync_engine)?; - let sync_engine = try_unwrap_mut(&mut sync_engine)?; + self.run(async move |coro, guard| { + let _lock = guard.pull_lock(); + let sync_engine = try_read(&guard.inner)?; + let sync_engine = try_unwrap(&sync_engine)?; sync_engine.sync(coro).await?; Ok(None) }) @@ -256,8 +301,9 @@ impl SyncEngine { #[napi] pub fn push(&self) -> GeneratorHolder { - self.run(async move |coro, sync_engine| { - let sync_engine = try_read(sync_engine)?; + self.run(async move |coro, guard| { + let _lock = guard.push_lock(); + let sync_engine = try_read(&guard.inner)?; let sync_engine = try_unwrap(&sync_engine)?; sync_engine.push_changes_to_remote(coro).await?; Ok(None) @@ -266,8 +312,8 @@ impl SyncEngine { #[napi] pub fn stats(&self) -> GeneratorHolder { - self.run(async move |coro, sync_engine| { - let sync_engine = try_read(sync_engine)?; + self.run(async move |coro, guard| { + let sync_engine = try_read(&guard.inner)?; let sync_engine = try_unwrap(&sync_engine)?; let stats = sync_engine.stats(coro).await?; Ok(Some(GeneratorResponse::SyncEngineStats { @@ -283,19 +329,25 @@ impl SyncEngine { #[napi] pub fn pull(&self) -> GeneratorHolder { - self.run(async move |coro, sync_engine| { - let mut sync_engine = try_write(sync_engine)?; - let sync_engine = try_unwrap_mut(&mut sync_engine)?; - sync_engine.pull_changes_from_remote(coro).await?; + self.run(async move |coro, guard| { + let sync_engine = try_read(&guard.inner)?; + let sync_engine = try_unwrap(&sync_engine)?; + let changes = { + let _lock = guard.wait_lock(); + sync_engine.wait_changes_from_remote(coro).await? + }; + let _lock = guard.pull_lock(); + sync_engine.apply_changes_from_remote(coro, changes).await?; Ok(None) }) } #[napi] pub fn checkpoint(&self) -> GeneratorHolder { - self.run(async move |coro, sync_engine| { - let mut sync_engine = try_write(sync_engine)?; - let sync_engine = try_unwrap_mut(&mut sync_engine)?; + self.run(async move |coro, guard| { + let _lock = guard.checkpoint_lock(); + let sync_engine = try_read(&guard.inner)?; + let sync_engine = try_unwrap(&sync_engine)?; sync_engine.checkpoint(coro).await?; Ok(None) }) @@ -315,7 +367,7 @@ impl SyncEngine { #[napi] pub fn close(&mut self) { - let _ = self.sync_engine.write().unwrap().take(); + let _ = self.sync_engine.inner.write().unwrap().take(); let _ = self.opened.lock().unwrap().take().unwrap(); let _ = self.io.take(); let _ = self.protocol.take(); @@ -344,7 +396,7 @@ impl SyncEngine { &self, f: impl AsyncFnOnce( &Coro<()>, - &Arc>>>, + &Arc, ) -> turso_sync_engine::Result> + 'static, ) -> GeneratorHolder { @@ -378,18 +430,6 @@ fn try_read( Ok(sync_engine) } -fn try_write( - sync_engine: &RwLock>>, -) -> turso_sync_engine::Result>>> { - let Ok(sync_engine) = sync_engine.try_write() else { - let nasty_error = "sync_engine is busy".to_string(); - return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError( - nasty_error, - )); - }; - Ok(sync_engine) -} - fn try_unwrap<'a>( sync_engine: &'a RwLockReadGuard<'_, Option>>, ) -> turso_sync_engine::Result<&'a DatabaseSyncEngine> { @@ -401,15 +441,3 @@ fn try_unwrap<'a>( }; Ok(sync_engine) } - -fn try_unwrap_mut<'a>( - sync_engine: &'a mut RwLockWriteGuard<'_, Option>>, -) -> turso_sync_engine::Result<&'a mut DatabaseSyncEngine> { - let Some(sync_engine) = sync_engine.as_mut() else { - let error = "sync_engine must be initialized".to_string(); - return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError( - error, - )); - }; - Ok(sync_engine) -}