properly guard access to the sync engine with locks

This commit is contained in:
Nikita Sivukhin
2025-09-16 12:21:17 +04:00
parent 160119b12e
commit 83303b8c5b

View File

@@ -6,7 +6,7 @@ pub mod js_protocol_io;
use std::{
collections::HashMap,
sync::{Arc, Mutex, OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard},
sync::{Arc, Mutex, MutexGuard, OnceLock, RwLock, RwLockReadGuard},
};
use napi::bindgen_prelude::{AsyncTask, Either5, Null};
@@ -28,17 +28,50 @@ pub struct DatabaseOpts {
pub path: String,
}
pub struct SyncEngineGuard {
inner: Arc<RwLock<Option<DatabaseSyncEngine<JsProtocolIo>>>>,
wait_lock: Mutex<()>,
push_lock: Mutex<()>,
pull_lock: Mutex<()>,
checkpoint_lock: Mutex<()>,
}
impl SyncEngineGuard {
fn checkpoint_lock(&self) -> (MutexGuard<'_, ()>, MutexGuard<'_, ()>, MutexGuard<'_, ()>) {
let push = self.push_lock.lock().unwrap();
let pull = self.pull_lock.lock().unwrap();
let checkpoint = self.checkpoint_lock.lock().unwrap();
(push, pull, checkpoint)
}
fn pull_lock(&self) -> (MutexGuard<'_, ()>, MutexGuard<'_, ()>, MutexGuard<'_, ()>) {
let wait = self.wait_lock.lock().unwrap();
let push = self.push_lock.lock().unwrap();
let pull = self.pull_lock.lock().unwrap();
(wait, push, pull)
}
fn push_lock(&self) -> MutexGuard<'_, ()> {
let push = self.push_lock.lock().unwrap();
push
}
fn wait_lock(&self) -> (MutexGuard<'_, ()>, MutexGuard<'_, ()>) {
let wait = self.wait_lock.lock().unwrap();
let pull = self.pull_lock.lock().unwrap();
(wait, pull)
}
}
#[napi]
pub struct SyncEngine {
path: String,
client_name: String,
wal_pull_batch_size: u32,
long_poll_timeout: Option<std::time::Duration>,
protocol_version: DatabaseSyncEngineProtocolVersion,
tables_ignore: Vec<String>,
use_transform: bool,
io: Option<Arc<dyn turso_core::IO>>,
protocol: Option<Arc<JsProtocolIo>>,
sync_engine: Arc<RwLock<Option<DatabaseSyncEngine<JsProtocolIo>>>>,
sync_engine: Arc<SyncEngineGuard>,
opened: Arc<Mutex<Option<turso_node::Database>>>,
}
@@ -123,6 +156,7 @@ pub struct SyncEngineOpts {
pub path: String,
pub client_name: Option<String>,
pub wal_pull_batch_size: Option<u32>,
pub long_poll_timeout_ms: Option<u32>,
pub tracing: Option<String>,
pub tables_ignore: Option<Vec<String>>,
pub use_transform: bool,
@@ -174,10 +208,19 @@ impl SyncEngine {
path: opts.path,
client_name: opts.client_name.unwrap_or("turso-sync-js".to_string()),
wal_pull_batch_size: opts.wal_pull_batch_size.unwrap_or(100),
long_poll_timeout: opts
.long_poll_timeout_ms
.map(|x| std::time::Duration::from_millis(x as u64)),
tables_ignore: opts.tables_ignore.unwrap_or_default(),
use_transform: opts.use_transform,
#[allow(clippy::arc_with_non_send_sync)]
sync_engine: Arc::new(RwLock::new(None)),
sync_engine: Arc::new(SyncEngineGuard {
inner: Arc::new(RwLock::new(None)),
wait_lock: Mutex::new(()),
push_lock: Mutex::new(()),
pull_lock: Mutex::new(()),
checkpoint_lock: Mutex::new(()),
}),
io: Some(io),
protocol: Some(Arc::new(JsProtocolIo::default())),
#[allow(clippy::arc_with_non_send_sync)]
@@ -196,6 +239,7 @@ impl SyncEngine {
let opts = DatabaseSyncEngineOpts {
client_name: self.client_name.clone(),
wal_pull_batch_size: self.wal_pull_batch_size as u64,
long_poll_timeout: self.long_poll_timeout,
tables_ignore: self.tables_ignore.clone(),
use_transform: self.use_transform,
protocol_version_hint: self.protocol_version,
@@ -213,7 +257,7 @@ impl SyncEngine {
let connection = initialized.connect_rw(&coro).await?;
let db = turso_node::Database::create(None, io.clone(), connection, path);
*sync_engine.write().unwrap() = Some(initialized);
*sync_engine.inner.write().unwrap() = Some(initialized);
*opened.lock().unwrap() = Some(db);
Ok(())
});
@@ -246,9 +290,10 @@ impl SyncEngine {
#[napi]
pub fn sync(&self) -> GeneratorHolder {
self.run(async move |coro, sync_engine| {
let mut sync_engine = try_write(sync_engine)?;
let sync_engine = try_unwrap_mut(&mut sync_engine)?;
self.run(async move |coro, guard| {
let _lock = guard.pull_lock();
let sync_engine = try_read(&guard.inner)?;
let sync_engine = try_unwrap(&sync_engine)?;
sync_engine.sync(coro).await?;
Ok(None)
})
@@ -256,8 +301,9 @@ impl SyncEngine {
#[napi]
pub fn push(&self) -> GeneratorHolder {
self.run(async move |coro, sync_engine| {
let sync_engine = try_read(sync_engine)?;
self.run(async move |coro, guard| {
let _lock = guard.push_lock();
let sync_engine = try_read(&guard.inner)?;
let sync_engine = try_unwrap(&sync_engine)?;
sync_engine.push_changes_to_remote(coro).await?;
Ok(None)
@@ -266,8 +312,8 @@ impl SyncEngine {
#[napi]
pub fn stats(&self) -> GeneratorHolder {
self.run(async move |coro, sync_engine| {
let sync_engine = try_read(sync_engine)?;
self.run(async move |coro, guard| {
let sync_engine = try_read(&guard.inner)?;
let sync_engine = try_unwrap(&sync_engine)?;
let stats = sync_engine.stats(coro).await?;
Ok(Some(GeneratorResponse::SyncEngineStats {
@@ -283,19 +329,25 @@ impl SyncEngine {
#[napi]
pub fn pull(&self) -> GeneratorHolder {
self.run(async move |coro, sync_engine| {
let mut sync_engine = try_write(sync_engine)?;
let sync_engine = try_unwrap_mut(&mut sync_engine)?;
sync_engine.pull_changes_from_remote(coro).await?;
self.run(async move |coro, guard| {
let sync_engine = try_read(&guard.inner)?;
let sync_engine = try_unwrap(&sync_engine)?;
let changes = {
let _lock = guard.wait_lock();
sync_engine.wait_changes_from_remote(coro).await?
};
let _lock = guard.pull_lock();
sync_engine.apply_changes_from_remote(coro, changes).await?;
Ok(None)
})
}
#[napi]
pub fn checkpoint(&self) -> GeneratorHolder {
self.run(async move |coro, sync_engine| {
let mut sync_engine = try_write(sync_engine)?;
let sync_engine = try_unwrap_mut(&mut sync_engine)?;
self.run(async move |coro, guard| {
let _lock = guard.checkpoint_lock();
let sync_engine = try_read(&guard.inner)?;
let sync_engine = try_unwrap(&sync_engine)?;
sync_engine.checkpoint(coro).await?;
Ok(None)
})
@@ -315,7 +367,7 @@ impl SyncEngine {
#[napi]
pub fn close(&mut self) {
let _ = self.sync_engine.write().unwrap().take();
let _ = self.sync_engine.inner.write().unwrap().take();
let _ = self.opened.lock().unwrap().take().unwrap();
let _ = self.io.take();
let _ = self.protocol.take();
@@ -344,7 +396,7 @@ impl SyncEngine {
&self,
f: impl AsyncFnOnce(
&Coro<()>,
&Arc<RwLock<Option<DatabaseSyncEngine<JsProtocolIo>>>>,
&Arc<SyncEngineGuard>,
) -> turso_sync_engine::Result<Option<GeneratorResponse>>
+ 'static,
) -> GeneratorHolder {
@@ -378,18 +430,6 @@ fn try_read(
Ok(sync_engine)
}
fn try_write(
sync_engine: &RwLock<Option<DatabaseSyncEngine<JsProtocolIo>>>,
) -> turso_sync_engine::Result<RwLockWriteGuard<'_, Option<DatabaseSyncEngine<JsProtocolIo>>>> {
let Ok(sync_engine) = sync_engine.try_write() else {
let nasty_error = "sync_engine is busy".to_string();
return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError(
nasty_error,
));
};
Ok(sync_engine)
}
fn try_unwrap<'a>(
sync_engine: &'a RwLockReadGuard<'_, Option<DatabaseSyncEngine<JsProtocolIo>>>,
) -> turso_sync_engine::Result<&'a DatabaseSyncEngine<JsProtocolIo>> {
@@ -401,15 +441,3 @@ fn try_unwrap<'a>(
};
Ok(sync_engine)
}
fn try_unwrap_mut<'a>(
sync_engine: &'a mut RwLockWriteGuard<'_, Option<DatabaseSyncEngine<JsProtocolIo>>>,
) -> turso_sync_engine::Result<&'a mut DatabaseSyncEngine<JsProtocolIo>> {
let Some(sync_engine) = sync_engine.as_mut() else {
let error = "sync_engine must be initialized".to_string();
return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError(
error,
));
};
Ok(sync_engine)
}