diff --git a/Cargo.lock b/Cargo.lock index 754e91077..014ed79e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1384,6 +1384,21 @@ dependencies = [ "slab", ] +[[package]] +name = "genawaiter" +version = "0.99.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c86bd0361bcbde39b13475e6e36cb24c329964aa2611be285289d1e4b751c1a0" +dependencies = [ + "genawaiter-macro", +] + +[[package]] +name = "genawaiter-macro" +version = "0.99.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b32dfe1fdfc0bbde1f22a5da25355514b5e450c33a6af6770884c8750aedfbc" + [[package]] name = "generic-array" version = "0.14.7" @@ -1609,7 +1624,7 @@ dependencies = [ "hyper", "libc", "pin-project-lite", - "socket2 0.5.10", + "socket2", "tokio", "tower-service", "tracing", @@ -3590,16 +3605,6 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "socket2" version = "0.6.0" @@ -3970,7 +3975,7 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "slab", - "socket2 0.6.0", + "socket2", "tokio-macros", "windows-sys 0.59.0", ] @@ -4353,6 +4358,29 @@ dependencies = [ "turso", ] +[[package]] +name = "turso_sync_engine" +version = "0.1.4-pre.2" +dependencies = [ + "bytes", + "ctor", + "futures", + "genawaiter", + "http", + "rand 0.9.2", + "rand_chacha 0.9.0", + "serde", + "serde_json", + "tempfile", + "thiserror 2.0.12", + "tokio", + "tracing", + "tracing-subscriber", + "turso", + "turso_core", + "uuid", +] + [[package]] name = "typenum" version = "1.18.0" diff --git a/Cargo.toml b/Cargo.toml index d77c9045d..1cea356af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ 
-27,6 +27,7 @@ members = [ "tests", "vendored/sqlite3-parser/sqlparser_bench", "packages/turso-sync", + "packages/turso-sync-engine", ] exclude = ["perf/latency/limbo"] diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 79560491e..1e037a92f 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -150,12 +150,7 @@ impl Database { /// Connect to the database. pub fn connect(&self) -> Result { let conn = self.inner.connect()?; - #[allow(clippy::arc_with_non_send_sync)] - let connection = Connection { - inner: Arc::new(Mutex::new(conn)), - transaction_behavior: TransactionBehavior::Deferred, - }; - Ok(connection) + Ok(Connection::create(conn)) } } @@ -178,6 +173,14 @@ unsafe impl Send for Connection {} unsafe impl Sync for Connection {} impl Connection { + pub fn create(conn: Arc) -> Self { + #[allow(clippy::arc_with_non_send_sync)] + let connection = Connection { + inner: Arc::new(Mutex::new(conn)), + transaction_behavior: TransactionBehavior::Deferred, + }; + connection + } /// Query the database with SQL. pub async fn query(&self, sql: &str, params: impl IntoParams) -> Result { let mut stmt = self.prepare(sql).await?; diff --git a/core/storage/pager.rs b/core/storage/pager.rs index c260793d3..649c8cf97 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1166,7 +1166,7 @@ impl Pager { "wal_frame_count() called on database without WAL".to_string(), )); }; - Ok(wal.borrow().get_max_frame_in_wal()) + Ok(wal.borrow().get_max_frame()) } /// Flush all dirty pages to disk. 
diff --git a/core/storage/wal.rs b/core/storage/wal.rs index b96d0bf26..84f66196d 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -868,8 +868,8 @@ impl Wal for WalFile { // if it's not, than pages from WAL range [frame_watermark..nBackfill] are already in the DB file, // and in case if page first occurrence in WAL was after frame_watermark - we will be unable to read proper previous version of the page turso_assert!( - frame_watermark.is_none() || frame_watermark.unwrap() >= self.min_frame, - "frame_watermark must be >= than current WAL min_value value" + frame_watermark.is_none() || frame_watermark.unwrap() >= self.get_shared().nbackfills.load(Ordering::SeqCst), + "frame_watermark must be >= than current WAL backfill amount: frame_watermark={:?}, nBackfill={}", frame_watermark, self.get_shared().nbackfills.load(Ordering::SeqCst) ); // if we are holding read_lock 0, skip and read right from db file. @@ -905,7 +905,7 @@ impl Wal for WalFile { let buf_len = buf.len(); turso_assert!( bytes_read == buf_len as i32, - "read({bytes_read}) less than expected({buf_len})" + "read({bytes_read}) less than expected({buf_len}): frame_id={frame_id}" ); let frame = frame.clone(); finish_read_page(page.get().id, buf, frame).unwrap(); diff --git a/packages/turso-sync-engine/.gitignore b/packages/turso-sync-engine/.gitignore new file mode 100644 index 000000000..45d1f6d52 --- /dev/null +++ b/packages/turso-sync-engine/.gitignore @@ -0,0 +1 @@ +!test_empty.db diff --git a/packages/turso-sync-engine/Cargo.toml b/packages/turso-sync-engine/Cargo.toml new file mode 100644 index 000000000..d49aa1df9 --- /dev/null +++ b/packages/turso-sync-engine/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "turso_sync_engine" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +turso_core = { workspace = true, features = ["conn_raw_api"] } +thiserror = "2.0.12" +tracing = "0.1.41" 
+serde_json.workspace = true +serde = { workspace = true, features = ["derive"] } +bytes = "1.10.1" +genawaiter = { version = "0.99.1", default-features = false } +http = "1.3.1" +uuid = "1.17.0" + +[dev-dependencies] +ctor = "0.4.2" +tempfile = "3.20.0" +tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } +tokio = { version = "1.46.1", features = ["macros", "rt-multi-thread", "test-util"] } +uuid = "1.17.0" +rand = "0.9.2" +rand_chacha = "0.9.0" +turso = { workspace = true, features = ["conn_raw_api"] } +futures = "0.3.31" diff --git a/packages/turso-sync-engine/src/database_sync_engine.rs b/packages/turso-sync-engine/src/database_sync_engine.rs new file mode 100644 index 000000000..fdb879ab9 --- /dev/null +++ b/packages/turso-sync-engine/src/database_sync_engine.rs @@ -0,0 +1,987 @@ +use std::{path::Path, sync::Arc}; + +use crate::{ + database_sync_operations::{ + db_bootstrap, reset_wal_file, transfer_logical_changes, transfer_physical_changes, + wait_full_body, wal_pull, wal_push, WalPullResult, + }, + errors::Error, + io_operations::IoOperations, + protocol_io::ProtocolIO, + types::{Coro, DatabaseMetadata}, + wal_session::WalSession, + Result, +}; + +#[derive(Debug)] +pub struct DatabaseSyncEngineOpts { + client_name: String, + wal_pull_batch_size: u64, +} + +pub struct DatabaseSyncEngine { + io: Arc, + protocol: Arc

, + draft_path: String, + synced_path: String, + meta_path: String, + opts: DatabaseSyncEngineOpts, + meta: Option, + // we remember information if Synced DB is dirty - which will make Database to reset it in case of any sync attempt + // this bit is set to false when we properly reset Synced DB + // this bit is set to true when we transfer changes from Draft to Synced or on initialization + synced_is_dirty: bool, +} + +async fn update_meta( + coro: &Coro, + io: &IO, + meta_path: &str, + orig: &mut Option, + update: impl FnOnce(&mut DatabaseMetadata), +) -> Result<()> { + let mut meta = orig.as_ref().unwrap().clone(); + update(&mut meta); + tracing::debug!("update_meta: {meta:?}"); + let completion = io.full_write(meta_path, meta.dump()?)?; + // todo: what happen if we will actually update the metadata on disk but fail and so in memory state will not be updated + wait_full_body(coro, &completion).await?; + *orig = Some(meta); + Ok(()) +} + +async fn set_meta( + coro: &Coro, + io: &IO, + meta_path: &str, + orig: &mut Option, + meta: DatabaseMetadata, +) -> Result<()> { + tracing::debug!("set_meta: {meta:?}"); + let completion = io.full_write(meta_path, meta.dump()?)?; + // todo: what happen if we will actually update the metadata on disk but fail and so in memory state will not be updated + wait_full_body(coro, &completion).await?; + *orig = Some(meta); + Ok(()) +} + +impl DatabaseSyncEngine { + /// Creates new instance of SyncEngine and initialize it immediately if no consistent local data exists + pub async fn new( + coro: &Coro, + io: Arc, + http_client: Arc, + path: &Path, + opts: DatabaseSyncEngineOpts, + ) -> Result { + let Some(path) = path.to_str() else { + let error = format!("invalid path: {path:?}"); + return Err(Error::DatabaseSyncEngineError(error)); + }; + let mut db = Self { + io, + protocol: http_client, + draft_path: format!("{path}-draft"), + synced_path: format!("{path}-synced"), + meta_path: format!("{path}-info"), + opts, + meta: None, + 
synced_is_dirty: true, + }; + db.init(coro).await?; + Ok(db) + } + + /// Create database connection and appropriately configure it before use + pub async fn connect(&self, coro: &Coro) -> Result> { + let db = self.io.open_tape(&self.draft_path, true)?; + db.connect(coro).await + } + + /// Sync all new changes from remote DB and apply them locally + /// This method will **not** send local changed to the remote + /// This method will block writes for the period of pull + pub async fn pull(&mut self, coro: &Coro) -> Result<()> { + tracing::debug!( + "pull: draft={}, synced={}", + self.draft_path, + self.synced_path + ); + + // reset Synced DB if it wasn't properly cleaned-up on previous "sync-method" attempt + self.reset_synced_if_dirty(coro).await?; + + // update Synced DB with fresh changes from remote + self.pull_synced_from_remote(coro).await?; + + // we will "replay" Synced WAL to the Draft WAL later without pushing it to the remote + // so, we pass 'capture: true' as we need to preserve all changes for future push of WAL + let draft = self.io.open_tape(&self.draft_path, true)?; + let synced = self.io.open_tape(&self.synced_path, true)?; + + { + // we will start wal write session for Draft DB in order to hold write lock during transfer of changes + let mut draft_session = WalSession::new(draft.connect(coro).await?); + draft_session.begin()?; + + // mark Synced as dirty as we will start transfer of logical changes there and if we will fail in the middle - we will need to cleanup Synced db + self.synced_is_dirty = true; + + // transfer logical changes to the Synced DB in order to later execute physical "rebase" operation + let client_id = &self.meta().client_unique_id; + transfer_logical_changes(coro, &draft, &synced, client_id, true).await?; + + // now we are ready to do the rebase: let's transfer physical changes from Synced to Draft + let synced_wal_watermark = self.meta().synced_wal_match_watermark; + let synced_sync_watermark = 
self.meta().synced_frame_no.expect( + "synced_frame_no must be set as we call pull_synced_from_remote before that", + ); + let draft_wal_watermark = self.meta().draft_wal_match_watermark; + let draft_sync_watermark = transfer_physical_changes( + coro, + &synced, + draft_session, + synced_wal_watermark, + synced_sync_watermark, + draft_wal_watermark, + ) + .await?; + update_meta( + coro, + self.protocol.as_ref(), + &self.meta_path, + &mut self.meta, + |m| { + m.draft_wal_match_watermark = draft_sync_watermark; + m.synced_wal_match_watermark = synced_sync_watermark; + }, + ) + .await?; + } + + // Synced DB is 100% dirty now - let's reset it + assert!(self.synced_is_dirty); + self.reset_synced_if_dirty(coro).await?; + + Ok(()) + } + + /// Sync local changes to remote DB + /// This method will **not** pull remote changes to the local DB + /// This method will **not** block writes for the period of sync + pub async fn push(&mut self, coro: &Coro) -> Result<()> { + tracing::debug!( + "push: draft={}, synced={}", + self.draft_path, + self.synced_path + ); + + // reset Synced DB if it wasn't properly cleaned-up on previous "sync-method" attempt + self.reset_synced_if_dirty(coro).await?; + + // update Synced DB with fresh changes from remote in order to avoid WAL frame conflicts + self.pull_synced_from_remote(coro).await?; + + // we will push Synced WAL to the remote + // so, we pass 'capture: false' as we don't need to preserve changes made to Synced WAL in turso_cdc + let draft = self.io.open_tape(&self.draft_path, true)?; + let synced = self.io.open_tape(&self.synced_path, false)?; + + // mark Synced as dirty as we will start transfer of logical changes there and if we will fail in the middle - we will need to cleanup Synced db + self.synced_is_dirty = true; + + let client_id = &self.meta().client_unique_id; + transfer_logical_changes(coro, &draft, &synced, client_id, false).await?; + + self.push_synced_to_remote(coro).await?; + Ok(()) + } + + /// Sync local changes to 
remote DB and bring new changes from remote to local + /// This method will block writes for the period of sync + pub async fn sync(&mut self, coro: &Coro) -> Result<()> { + // todo(sivukhin): this is bit suboptimal as both 'push' and 'pull' will call pull_synced_from_remote + // but for now - keep it simple + self.push(coro).await?; + self.pull(coro).await?; + Ok(()) + } + + async fn init(&mut self, coro: &Coro) -> Result<()> { + tracing::debug!( + "initialize sync engine: draft={}, synced={}, opts={:?}", + self.draft_path, + self.synced_path, + self.opts, + ); + + let completion = self.protocol.full_read(&self.meta_path)?; + let data = wait_full_body(coro, &completion).await?; + let meta = if data.is_empty() { + None + } else { + Some(DatabaseMetadata::load(&data)?) + }; + + match meta { + Some(meta) => { + self.meta = Some(meta); + } + None => { + let meta = self.bootstrap_db_files(coro).await?; + tracing::debug!("write meta after successful bootstrap: meta={meta:?}"); + set_meta( + coro, + self.protocol.as_ref(), + &self.meta_path, + &mut self.meta, + meta, + ) + .await?; + } + }; + + let draft_exists = self.io.try_open(&self.draft_path)?.is_some(); + let synced_exists = self.io.try_open(&self.synced_path)?.is_some(); + if !draft_exists || !synced_exists { + let error = "Draft or Synced files doesn't exists, but metadata is".to_string(); + return Err(Error::DatabaseSyncEngineError(error)); + } + + if self.meta().synced_frame_no.is_none() { + // sync WAL from the remote in case of bootstrap - all subsequent initializations will be fast + self.pull(coro).await?; + } + Ok(()) + } + + async fn pull_synced_from_remote(&mut self, coro: &Coro) -> Result<()> { + tracing::debug!( + "pull_synced_from_remote: draft={:?}, synced={:?}", + self.draft_path, + self.synced_path, + ); + let synced = self.io.open_tape(&self.synced_path, false)?; + let synced_conn = synced.connect(coro).await?; + let mut wal = WalSession::new(synced_conn); + + let generation = 
self.meta().synced_generation; + let mut start_frame = self.meta().synced_frame_no.unwrap_or(0) + 1; + loop { + let end_frame = start_frame + self.opts.wal_pull_batch_size; + let update = async |coro, frame_no| { + update_meta( + coro, + self.protocol.as_ref(), + &self.meta_path, + &mut self.meta, + |m| m.synced_frame_no = Some(frame_no), + ) + .await + }; + match wal_pull( + coro, + self.protocol.as_ref(), + &mut wal, + generation, + start_frame, + end_frame, + update, + ) + .await? + { + WalPullResult::Done => return Ok(()), + WalPullResult::PullMore => { + start_frame = end_frame; + continue; + } + WalPullResult::NeedCheckpoint => { + return Err(Error::DatabaseSyncEngineError( + "checkpoint is temporary not supported".to_string(), + )); + } + } + } + } + + async fn push_synced_to_remote(&mut self, coro: &Coro) -> Result<()> { + tracing::debug!( + "push_synced_to_remote: draft={}, synced={}, id={}", + self.draft_path, + self.synced_path, + self.meta().client_unique_id + ); + let synced = self.io.open_tape(&self.synced_path, false)?; + let synced_conn = synced.connect(coro).await?; + + let mut wal = WalSession::new(synced_conn); + wal.begin()?; + + // todo(sivukhin): push frames in multiple batches + let generation = self.meta().synced_generation; + let start_frame = self.meta().synced_frame_no.unwrap_or(0) + 1; + let end_frame = wal.conn().wal_frame_count()? 
+ 1; + match wal_push( + coro, + self.protocol.as_ref(), + &mut wal, + None, + generation, + start_frame, + end_frame, + ) + .await + { + Ok(_) => { + update_meta( + coro, + self.protocol.as_ref(), + &self.meta_path, + &mut self.meta, + |m| m.synced_frame_no = Some(end_frame - 1), + ) + .await?; + self.synced_is_dirty = false; + Ok(()) + } + Err(err) => { + tracing::info!("push_synced_to_remote: failed: err={err}"); + Err(err) + } + } + } + + async fn bootstrap_db_files(&mut self, coro: &Coro) -> Result { + assert!( + self.meta.is_none(), + "bootstrap_db_files must be called only when meta is not set" + ); + tracing::debug!( + "bootstrap_db_files: draft={}, synced={}", + self.draft_path, + self.synced_path, + ); + + let start_time = std::time::Instant::now(); + // cleanup all files left from previous attempt to bootstrap + // we shouldn't write any WAL files - but let's truncate them too for safety + if let Some(file) = self.io.try_open(&self.draft_path)? { + self.io.truncate(coro, file, 0).await?; + } + if let Some(file) = self.io.try_open(&self.synced_path)? { + self.io.truncate(coro, file, 0).await?; + } + if let Some(file) = self.io.try_open(&format!("{}-wal", self.draft_path))? { + self.io.truncate(coro, file, 0).await?; + } + if let Some(file) = self.io.try_open(&format!("{}-wal", self.synced_path))? 
{ + self.io.truncate(coro, file, 0).await?; + } + + let files = &[ + self.io.create(&self.draft_path)?, + self.io.create(&self.synced_path)?, + ]; + let db_info = db_bootstrap(coro, self.protocol.as_ref(), files).await?; + + let elapsed = std::time::Instant::now().duration_since(start_time); + tracing::debug!( + "bootstrap_db_files: finished draft={:?}, synced={:?}: elapsed={:?}", + self.draft_path, + self.synced_path, + elapsed + ); + + Ok(DatabaseMetadata { + client_unique_id: format!("{}-{}", self.opts.client_name, uuid::Uuid::new_v4()), + synced_generation: db_info.current_generation, + synced_frame_no: None, + draft_wal_match_watermark: 0, + synced_wal_match_watermark: 0, + }) + } + + /// Reset WAL of Synced database which potentially can have some local changes + async fn reset_synced_if_dirty(&mut self, coro: &Coro) -> Result<()> { + tracing::debug!( + "reset_synced: synced_path={:?}, synced_is_dirty={}", + self.synced_path, + self.synced_is_dirty + ); + // if we know that Synced DB is not dirty - let's skip this phase completely + if !self.synced_is_dirty { + return Ok(()); + } + if let Some(synced_wal) = self.io.try_open(&format!("{}-wal", self.synced_path))? 
{ + reset_wal_file(coro, synced_wal, self.meta().synced_frame_no.unwrap_or(0)).await?; + } + self.synced_is_dirty = false; + Ok(()) + } + + fn meta(&self) -> &DatabaseMetadata { + self.meta.as_ref().expect("metadata must be set") + } +} + +#[cfg(test)] +pub mod tests { + use std::sync::Arc; + + use rand::RngCore; + + use crate::{ + database_sync_engine::DatabaseSyncEngineOpts, + test_context::{FaultInjectionStrategy, TestContext}, + test_protocol_io::TestProtocolIo, + test_sync_server::convert_rows, + tests::{deterministic_runtime, seed_u64, TestRunner}, + Result, + }; + + async fn query_rows(conn: &turso::Connection, sql: &str) -> Result>> { + let mut rows = conn.query(sql, ()).await?; + convert_rows(&mut rows).await + } + + #[test] + pub fn test_sync_single_db_simple() { + deterministic_runtime(async || { + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let ctx = Arc::new(TestContext::new(seed_u64())); + let protocol = TestProtocolIo::new(ctx.clone(), &server_path) + .await + .unwrap(); + let mut runner = TestRunner::new(ctx.clone(), io, protocol.clone()); + let local_path = dir.path().join("local.db"); + let opts = DatabaseSyncEngineOpts { + client_name: "id-1".to_string(), + wal_pull_batch_size: 1, + }; + runner.init(local_path, opts).await.unwrap(); + + protocol + .server + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) + .await + .unwrap(); + protocol + .server + .execute("INSERT INTO t VALUES (1)", ()) + .await + .unwrap(); + + let conn = runner.connect().await.unwrap(); + + // no table in schema before sync from remote (as DB was initialized when remote was empty) + assert!(matches!( + query_rows(&conn, "SELECT * FROM t").await, + Err(x) if x.to_string().contains("no such table: t") + )); + + // 1 rows synced + runner.pull().await.unwrap(); + assert_eq!( + query_rows(&conn, "SELECT * FROM t").await.unwrap(), + 
vec![vec![turso::Value::Integer(1)]] + ); + + protocol + .server + .execute("INSERT INTO t VALUES (2)", ()) + .await + .unwrap(); + + conn.execute("INSERT INTO t VALUES (3)", ()).await.unwrap(); + + // changes are synced from the remote - but remote changes are not propagated locally + runner.push().await.unwrap(); + assert_eq!( + query_rows(&conn, "SELECT * FROM t").await.unwrap(), + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(3)] + ] + ); + + let server_db = protocol.server.db(); + let server_conn = server_db.connect().unwrap(); + assert_eq!( + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(), + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)], + ] + ); + + conn.execute("INSERT INTO t VALUES (4)", ()).await.unwrap(); + runner.push().await.unwrap(); + assert_eq!( + query_rows(&conn, "SELECT * FROM t").await.unwrap(), + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(3)], + vec![turso::Value::Integer(4)] + ] + ); + + assert_eq!( + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(), + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)], + vec![turso::Value::Integer(4)], + ] + ); + + runner.pull().await.unwrap(); + assert_eq!( + query_rows(&conn, "SELECT * FROM t").await.unwrap(), + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)], + vec![turso::Value::Integer(4)] + ] + ); + + assert_eq!( + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(), + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)], + vec![turso::Value::Integer(4)], + ] + ); + }); + } + + #[test] + pub fn test_sync_single_db_full_syncs() { + deterministic_runtime(async || { + let io: Arc = 
Arc::new(turso_core::PlatformIO::new().unwrap()); + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let ctx = Arc::new(TestContext::new(seed_u64())); + let server = TestProtocolIo::new(ctx.clone(), &server_path) + .await + .unwrap(); + let mut runner = TestRunner::new(ctx.clone(), io.clone(), server.clone()); + let local_path = dir.path().join("local.db"); + let opts = DatabaseSyncEngineOpts { + client_name: "id-1".to_string(), + wal_pull_batch_size: 1, + }; + runner.init(local_path, opts).await.unwrap(); + + server + .server + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) + .await + .unwrap(); + server + .server + .execute("INSERT INTO t VALUES (1)", ()) + .await + .unwrap(); + + let conn = runner.connect().await.unwrap(); + + // no table in schema before sync from remote (as DB was initialized when remote was empty) + assert!(matches!( + query_rows(&conn, "SELECT * FROM t").await, + Err(x) if x.to_string().contains("no such table: t") + )); + + runner.sync().await.unwrap(); + + assert_eq!( + query_rows(&conn, "SELECT * FROM t").await.unwrap(), + vec![vec![turso::Value::Integer(1)]] + ); + + conn.execute("INSERT INTO t VALUES (2)", ()).await.unwrap(); + runner.sync().await.unwrap(); + + assert_eq!( + query_rows(&conn, "SELECT * FROM t").await.unwrap(), + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)] + ] + ); + + conn.execute("INSERT INTO t VALUES (3)", ()).await.unwrap(); + runner.sync().await.unwrap(); + assert_eq!( + query_rows(&conn, "SELECT * FROM t").await.unwrap(), + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)] + ] + ); + }); + } + + #[test] + pub fn test_sync_multiple_dbs_conflict() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let ctx = Arc::new(TestContext::new(seed_u64())); + let io: Arc = 
Arc::new(turso_core::PlatformIO::new().unwrap()); + let protocol = TestProtocolIo::new(ctx.clone(), &server_path) + .await + .unwrap(); + let mut dbs = Vec::new(); + const CLIENTS: usize = 8; + for i in 0..CLIENTS { + let mut runner = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); + let local_path = dir.path().join(format!("local-{i}.db")); + let opts = DatabaseSyncEngineOpts { + client_name: format!("id-{i}"), + wal_pull_batch_size: 1, + }; + runner.init(local_path, opts).await.unwrap(); + dbs.push(runner); + } + + protocol + .server + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) + .await + .unwrap(); + + for db in &mut dbs { + db.pull().await.unwrap(); + } + for (i, db) in dbs.iter().enumerate() { + let conn = db.connect().await.unwrap(); + conn.execute("INSERT INTO t VALUES (?)", (i as i32,)) + .await + .unwrap(); + } + + let try_sync = || async { + let mut tasks = Vec::new(); + for db in &dbs { + tasks.push(async move { db.push().await }); + } + futures::future::join_all(tasks).await + }; + for attempt in 0..CLIENTS { + let results = try_sync().await; + tracing::info!("attempt #{}: {:?}", attempt, results); + assert!(results.iter().filter(|x| x.is_ok()).count() > attempt); + } + }); + } + + #[test] + pub fn test_sync_multiple_clients_no_conflicts_synchronized() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let ctx = Arc::new(TestContext::new(seed_u64())); + let protocol = TestProtocolIo::new(ctx.clone(), &server_path) + .await + .unwrap(); + + protocol + .server + .execute("CREATE TABLE t(k INTEGER PRIMARY KEY, v)", ()) + .await + .unwrap(); + + let sync_lock = Arc::new(tokio::sync::Mutex::new(())); + let mut clients = Vec::new(); + const CLIENTS: usize = 10; + let mut expected_rows = Vec::new(); + for i in 0..CLIENTS { + let mut queries = Vec::new(); + let cnt = 
ctx.rng().await.next_u32() % CLIENTS as u32 + 1; + for q in 0..cnt { + let key = i * CLIENTS + q as usize; + let length = ctx.rng().await.next_u32() % 4096; + queries.push(format!( + "INSERT INTO t VALUES ({key}, randomblob({length}))", + )); + expected_rows.push(vec![ + turso::Value::Integer(key as i64), + turso::Value::Integer(length as i64), + ]); + } + clients.push({ + let io = io.clone(); + let dir = dir.path().to_path_buf().clone(); + let ctx = ctx.clone(); + let server = protocol.clone(); + let sync_lock = sync_lock.clone(); + async move { + let mut runner = TestRunner::new(ctx.clone(), io.clone(), server.clone()); + let local_path = dir.join(format!("local-{i}.db")); + let opts = DatabaseSyncEngineOpts { + client_name: format!("id-{i}"), + wal_pull_batch_size: 1, + }; + runner.init(local_path, opts).await.unwrap(); + runner.pull().await.unwrap(); + let conn = runner.connect().await.unwrap(); + for query in queries { + conn.execute(&query, ()).await.unwrap(); + } + let guard = sync_lock.lock().await; + runner.push().await.unwrap(); + drop(guard); + } + }); + } + for client in clients { + client.await; + } + let db = protocol.server.db(); + let conn = db.connect().unwrap(); + let mut result = conn.query("SELECT k, length(v) FROM t", ()).await.unwrap(); + let rows = convert_rows(&mut result).await.unwrap(); + assert_eq!(rows, expected_rows); + }); + } + + #[test] + pub fn test_sync_single_db_sync_from_remote_nothing_single_failure() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let ctx = Arc::new(TestContext::new(seed_u64())); + let protocol = TestProtocolIo::new(ctx.clone(), &server_path) + .await + .unwrap(); + + protocol + .server + .execute("CREATE TABLE t(x)", ()) + .await + .unwrap(); + protocol + .server + .execute("INSERT INTO t VALUES (1), (2), (3)", ()) + .await + .unwrap(); + + let mut session = 
ctx.fault_session(); + let mut it = 0; + while let Some(strategy) = session.next().await { + it += 1; + + let mut runner = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); + let local_path = dir.path().join(format!("local-{it}.db")); + let opts = DatabaseSyncEngineOpts { + client_name: format!("id-{it}"), + wal_pull_batch_size: 1, + }; + runner.init(local_path, opts).await.unwrap(); + + let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. }); + + ctx.switch_mode(strategy).await; + let result = runner.pull().await; + ctx.switch_mode(FaultInjectionStrategy::Disabled).await; + + if !has_fault { + result.unwrap(); + } else { + let err = result.err().unwrap(); + tracing::info!("error after fault injection: {}", err); + } + + let conn = runner.connect().await.unwrap(); + let rows = query_rows(&conn, "SELECT * FROM t").await.unwrap(); + assert_eq!( + rows, + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)], + ] + ); + + runner.pull().await.unwrap(); + + let rows = query_rows(&conn, "SELECT * FROM t").await.unwrap(); + assert_eq!( + rows, + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)], + ] + ); + } + }); + } + + #[test] + pub fn test_sync_single_db_sync_from_remote_single_failure() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let ctx = Arc::new(TestContext::new(seed_u64())); + + let mut session = ctx.fault_session(); + let mut it = 0; + while let Some(strategy) = session.next().await { + it += 1; + + let server_path = dir.path().join(format!("server-{it}.db")); + let protocol = TestProtocolIo::new(ctx.clone(), &server_path) + .await + .unwrap(); + + protocol + .server + .execute("CREATE TABLE t(x)", ()) + .await + .unwrap(); + + let mut runner = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); + let local_path 
= dir.path().join(format!("local-{it}.db")); + let opts = DatabaseSyncEngineOpts { + client_name: format!("id-{it}"), + wal_pull_batch_size: 1, + }; + runner.init(local_path, opts).await.unwrap(); + + protocol + .server + .execute("INSERT INTO t VALUES (1), (2), (3)", ()) + .await + .unwrap(); + + let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. }); + + ctx.switch_mode(strategy).await; + let result = runner.pull().await; + ctx.switch_mode(FaultInjectionStrategy::Disabled).await; + + if !has_fault { + result.unwrap(); + } else { + let err = result.err().unwrap(); + tracing::info!("error after fault injection: {}", err); + } + + let conn = runner.connect().await.unwrap(); + let rows = query_rows(&conn, "SELECT * FROM t").await.unwrap(); + assert!(rows.len() <= 3); + + runner.pull().await.unwrap(); + + let rows = query_rows(&conn, "SELECT * FROM t").await.unwrap(); + assert_eq!( + rows, + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)], + ] + ); + } + }); + } + + #[test] + pub fn test_sync_single_db_sync_to_remote_single_failure() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let ctx = Arc::new(TestContext::new(seed_u64())); + + let mut session = ctx.fault_session(); + let mut it = 0; + while let Some(strategy) = session.next().await { + it += 1; + + let server_path = dir.path().join(format!("server-{it}.db")); + let protocol = TestProtocolIo::new(ctx.clone(), &server_path) + .await + .unwrap(); + + protocol + .server + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) + .await + .unwrap(); + + protocol + .server + .execute("INSERT INTO t VALUES (1)", ()) + .await + .unwrap(); + + let mut runner = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); + let local_path = dir.path().join(format!("local-{it}.db")); + let opts = DatabaseSyncEngineOpts { + client_name: 
format!("id-{it}"), + wal_pull_batch_size: 1, + }; + runner.init(local_path, opts).await.unwrap(); + + let conn = runner.connect().await.unwrap(); + + conn.execute("INSERT INTO t VALUES (2), (3)", ()) + .await + .unwrap(); + + let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. }); + + ctx.switch_mode(strategy).await; + let result = runner.push().await; + ctx.switch_mode(FaultInjectionStrategy::Disabled).await; + + if !has_fault { + result.unwrap(); + } else { + let err = result.err().unwrap(); + tracing::info!("error after fault injection: {}", err); + } + + let server_db = protocol.server.db(); + let server_conn = server_db.connect().unwrap(); + let rows = + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(); + assert!(rows.len() <= 3); + + runner.push().await.unwrap(); + + let rows = + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(); + assert_eq!( + rows, + vec![ + vec![turso::Value::Integer(1)], + vec![turso::Value::Integer(2)], + vec![turso::Value::Integer(3)], + ] + ); + } + }); + } +} diff --git a/packages/turso-sync-engine/src/database_sync_operations.rs b/packages/turso-sync-engine/src/database_sync_operations.rs new file mode 100644 index 000000000..eff3d8826 --- /dev/null +++ b/packages/turso-sync-engine/src/database_sync_operations.rs @@ -0,0 +1,732 @@ +use std::{rc::Rc, sync::Arc}; + +use turso_core::{types::Text, Buffer, Completion, LimboError, Value}; + +use crate::{ + database_tape::{ + exec_stmt, run_stmt, DatabaseChangesIteratorMode, DatabaseChangesIteratorOpts, + DatabaseReplaySessionOpts, DatabaseTape, DatabaseWalSession, + }, + errors::Error, + protocol_io::{DataCompletion, DataPollResult, ProtocolIO}, + types::{Coro, DatabaseTapeOperation, DbSyncInfo, DbSyncStatus, ProtocolCommand}, + wal_session::WalSession, + Result, +}; + +pub const WAL_HEADER: usize = 32; +pub const WAL_FRAME_HEADER: usize = 24; +const PAGE_SIZE: usize = 4096; 
/// Outcome of a single `wal_pull` call.
pub enum WalPullResult {
    /// The remote answered "checkpoint_needed" for the same generation with a
    /// `max_frame_no` below the requested start frame: there is nothing left to pull.
    Done,
    /// The frame stream for the requested range was consumed completely; the
    /// remote may still have more frames after it.
    PullMore,
    /// The remote answered "checkpoint_needed" and the end-of-history condition
    /// did not hold; the caller has to checkpoint before pulling again.
    NeedCheckpoint,
}
{ + break; + } + coro.yield_(ProtocolCommand::IO).await?; + } + + let elapsed = std::time::Instant::now().duration_since(start_time); + tracing::debug!("db_bootstrap: finished: bytes={pos}, elapsed={:?}", elapsed); + + Ok(db_info) +} + +/// Pull updates from remote to the database file +/// +/// Returns [WalPullResult::Done] if pull reached the end of database history +/// Returns [WalPullResult::PullMore] if all frames from [start_frame..end_frame) range were pulled, but remote have more +/// Returns [WalPullResult::NeedCheckpoint] if remote generation increased and local version must be checkpointed +/// +/// Guarantees: +/// 1. Frames are commited to the WAL (i.e. db_size is not zero 0) only at transaction boundaries from remote +/// 2. wal_pull is idempotent for fixed generation and can be called multiple times with same frame range +pub async fn wal_pull<'a, C: ProtocolIO, U: AsyncFnMut(&'a Coro, u64) -> Result<()>>( + coro: &'a Coro, + client: &C, + wal_session: &mut WalSession, + generation: u64, + mut start_frame: u64, + end_frame: u64, + mut update: U, +) -> Result { + tracing::debug!( + "wal_pull: generation={}, start_frame={}, end_frame={}", + generation, + start_frame, + end_frame + ); + + // todo(sivukhin): optimize allocation by using buffer pool in the DatabaseSyncOperations + let mut buffer = Vec::with_capacity(WAL_FRAME_SIZE); + + let result = wal_pull_http(coro, client, generation, start_frame, end_frame).await?; + let data = match result { + WalHttpPullResult::NeedCheckpoint(status) => { + assert!(status.status == "checkpoint_needed"); + tracing::debug!("wal_pull: need checkpoint: status={status:?}"); + if status.generation == generation && status.max_frame_no < start_frame { + tracing::debug!("wal_pull: end of history: status={:?}", status); + update(coro, status.max_frame_no).await?; + return Ok(WalPullResult::Done); + } + return Ok(WalPullResult::NeedCheckpoint); + } + WalHttpPullResult::Frames(content) => content, + }; + loop { + while let 
Some(chunk) = data.poll_data()? { + let mut chunk = chunk.data(); + while !chunk.is_empty() { + let to_fill = WAL_FRAME_SIZE - buffer.len(); + buffer.extend_from_slice(&chunk[0..to_fill]); + chunk = &chunk[to_fill..]; + + assert!( + buffer.capacity() == WAL_FRAME_SIZE, + "buffer should not extend its capacity" + ); + if buffer.len() < WAL_FRAME_SIZE { + continue; + } + if !wal_session.in_txn() { + wal_session.begin()?; + } + let frame_info = wal_session.insert_at(start_frame, &buffer)?; + if frame_info.is_commit_frame() { + wal_session.end()?; + // transaction boundary reached - safe to commit progress + update(coro, start_frame).await?; + } + buffer.clear(); + start_frame += 1; + } + } + if data.is_done()? { + break; + } + coro.yield_(ProtocolCommand::IO).await?; + } + if !buffer.is_empty() { + return Err(Error::DatabaseSyncEngineError(format!( + "wal_pull: response has unexpected trailing data: buffer.len()={}", + buffer.len() + ))); + } + Ok(WalPullResult::PullMore) +} + +/// Push frame range [start_frame..end_frame) to the remote +/// Returns baton for WAL remote-session in case of success +/// Returns [Error::DatabaseSyncEngineConflict] in case of frame conflict at remote side +/// +/// Guarantees: +/// 1. 
/// Push frame range [start_frame..end_frame) to the remote
/// Returns baton for WAL remote-session in case of success
/// Returns [Error::DatabaseSyncEngineConflict] in case of frame conflict at remote side
///
/// The WAL session must already be inside a transaction (asserted on entry).
/// An empty range (`start_frame == end_frame`) is a no-op that returns `Ok(None)`
/// without contacting the remote.
///
/// Guarantees:
/// 1. If there is a single client which calls wal_push, then this operation is idempotent for fixed generation
///    and can be called multiple times with same frame range
pub async fn wal_push(
    coro: &Coro,
    client: &C,
    wal_session: &mut WalSession,
    baton: Option,
    generation: u64,
    start_frame: u64,
    end_frame: u64,
) -> Result> {
    assert!(wal_session.in_txn());
    tracing::debug!("wal_push: baton={baton:?}, generation={generation}, start_frame={start_frame}, end_frame={end_frame}");

    if start_frame == end_frame {
        return Ok(None);
    }

    // all frames of the range are collected in memory and sent in a single request
    let mut frames_data = Vec::with_capacity((end_frame - start_frame) as usize * WAL_FRAME_SIZE);
    let mut buffer = [0u8; WAL_FRAME_SIZE];
    for frame_no in start_frame..end_frame {
        let frame_info = wal_session.read_at(frame_no, &mut buffer)?;
        tracing::trace!(
            "wal_push: collect frame {} ({:?}) for push",
            frame_no,
            frame_info
        );
        frames_data.extend_from_slice(&buffer);
    }

    // NOTE(review): the caller-supplied `baton` is only logged above; `None` is what gets
    // forwarded to wal_push_http, so the request never carries a baton. Confirm whether
    // this is intentional or whether `baton` should be passed through.
    let status = wal_push_http(
        coro,
        client,
        None,
        generation,
        start_frame,
        end_frame,
        frames_data,
    )
    .await?;
    if status.status == "conflict" {
        return Err(Error::DatabaseSyncEngineConflict(format!(
            "wal_push conflict: {status:?}"
        )));
    }
    // any status other than "conflict" and "ok" is treated as a generic sync engine error
    if status.status != "ok" {
        return Err(Error::DatabaseSyncEngineError(format!(
            "wal_push unexpected status: {status:?}"
        )));
    }
    Ok(status.baton)
}
/// Transfers row changes from source DB to target DB
/// In order to guarantee atomicity and avoid conflicts - method maintain last_change_id counter in the target db table turso_sync_last_change_id
///
/// Steps, in order:
/// 1. read `pull_gen` for `client_id` from the source DB (0 if it cannot be read)
/// 2. make sure the meta table exists in the target DB and read `(pull_gen, change_id)` for `client_id` from it,
///    inserting a `(client_id, 0, 0)` row if there is none
/// 3. iterate source changes starting right after the last transferred change id and replay them in the target
/// 4. on every commit that carried row changes (or on every commit when `bump_pull_gen` is set) update the
///    meta row through the replay session connection before the commit operation itself is replayed
///
/// Returns [Error::DatabaseSyncEngineError] if the meta row has unexpected column types or if the target
/// `pull_gen` is ahead of the source one.
pub async fn transfer_logical_changes(
    coro: &Coro,
    source: &DatabaseTape,
    target: &DatabaseTape,
    client_id: &str,
    bump_pull_gen: bool,
) -> Result<()> {
    tracing::debug!("transfer_logical_changes: client_id={client_id}");
    let source_conn = source.connect_untracked()?;
    let target_conn = target.connect_untracked()?;

    // fetch pull_gen recorded for this client in the source DB; it defaults to 0 when the
    // statement fails to prepare with a parse error or when there is no row for the client
    let source_pull_gen = 'source_pull_gen: {
        let mut select_last_change_id_stmt =
            match source_conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID) {
                Ok(stmt) => stmt,
                // NOTE(review): presumably a parse error here means the meta table does not exist in the source yet - confirm
                Err(LimboError::ParseError(..)) => break 'source_pull_gen 0,
                Err(err) => return Err(err.into()),
            };

        select_last_change_id_stmt
            .bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id)));

        match run_stmt(coro, &mut select_last_change_id_stmt).await? {
            Some(row) => row.get_value(0).as_int().ok_or_else(|| {
                Error::DatabaseSyncEngineError("unexpected source pull_gen type".to_string())
            })?,
            None => {
                tracing::debug!("transfer_logical_changes: client_id={client_id}, turso_sync_last_change_id table is not found");
                0
            }
        }
    };
    tracing::debug!(
        "transfer_logical_changes: client_id={client_id}, source_pull_gen={source_pull_gen}"
    );

    // create the meta table in the target DB if needed, then fetch last_change_id from it in order to
    // guarantee atomic replay of changes and avoid conflicts in case of failure
    let mut schema_stmt = target_conn.prepare(TURSO_SYNC_META_TABLE)?;
    exec_stmt(coro, &mut schema_stmt).await?;

    let mut select_last_change_id_stmt = target_conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID)?;
    select_last_change_id_stmt.bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id)));

    let mut last_change_id = match run_stmt(coro, &mut select_last_change_id_stmt).await? {
        Some(row) => {
            let target_pull_gen = row.get_value(0).as_int().ok_or_else(|| {
                Error::DatabaseSyncEngineError("unexpected target pull_gen type".to_string())
            })?;
            let target_change_id = row.get_value(1).as_int().ok_or_else(|| {
                Error::DatabaseSyncEngineError("unexpected target change_id type".to_string())
            })?;
            tracing::debug!(
                "transfer_logical_changes: client_id={client_id}, target_pull_gen={target_pull_gen}, target_change_id={target_change_id}"
            );
            if target_pull_gen > source_pull_gen {
                return Err(Error::DatabaseSyncEngineError(format!("protocol error: target_pull_gen > source_pull_gen: {target_pull_gen} > {source_pull_gen}")));
            }
            if target_pull_gen == source_pull_gen {
                // same generation: continue after the change id recorded in the target
                Some(target_change_id)
            } else {
                // target is behind the source generation: the recorded change id is not used
                Some(0)
            }
        }
        None => {
            // first transfer for this client: register it with pull_gen = 0, change_id = 0
            let mut insert_last_change_id_stmt =
                target_conn.prepare(TURSO_SYNC_INSERT_LAST_CHANGE_ID)?;
            insert_last_change_id_stmt
                .bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id)));
            exec_stmt(coro, &mut insert_last_change_id_stmt).await?;
            None
        }
    };

    tracing::debug!(
        "transfer_logical_changes: last_change_id={:?}",
        last_change_id
    );
    let replay_opts = DatabaseReplaySessionOpts {
        use_implicit_rowid: false,
    };
    let mut session = target.start_replay_session(coro, replay_opts).await?;

    let iterate_opts = DatabaseChangesIteratorOpts {
        // start right after the last transferred change (no lower bound when nothing was transferred yet)
        first_change_id: last_change_id.map(|x| x + 1),
        mode: DatabaseChangesIteratorMode::Apply,
        ..Default::default()
    };
    let mut changes = source.iterate_changes(iterate_opts)?;
    // set once at least one row change was seen, so that a commit has something to record
    let mut updated = false;
    while let Some(operation) = changes.next(coro).await? {
        match &operation {
            DatabaseTapeOperation::RowChange(change) => {
                assert!(
                    last_change_id.is_none() || last_change_id.unwrap() < change.change_id,
                    "change id must be strictly increasing: last_change_id={:?}, change.change_id={}",
                    last_change_id,
                    change.change_id
                );
                // we give user full control over CDC table - so let's not emit assert here for now
                if last_change_id.is_some() && last_change_id.unwrap() + 1 != change.change_id {
                    tracing::warn!(
                        "out of order change sequence: {} -> {}",
                        last_change_id.unwrap(),
                        change.change_id
                    );
                }
                last_change_id = Some(change.change_id);
                updated = true;
            }
            DatabaseTapeOperation::Commit if updated || bump_pull_gen => {
                tracing::debug!("prepare update stmt for turso_sync_last_change_id table with client_id={} and last_change_id={:?}", client_id, last_change_id);
                // update turso_sync_last_change_id table with new value before commit
                let mut set_last_change_id_stmt =
                    session.conn().prepare(TURSO_SYNC_UPDATE_LAST_CHANGE_ID)?;
                // bumping the generation resets the change counter to 0
                let (next_pull_gen, next_change_id) = if bump_pull_gen {
                    (source_pull_gen + 1, 0)
                } else {
                    (source_pull_gen, last_change_id.unwrap_or(0))
                };
                tracing::debug!("transfer_logical_changes: client_id={client_id}, set pull_gen={next_pull_gen}, change_id={next_change_id}");
                set_last_change_id_stmt
                    .bind_at(1.try_into().unwrap(), Value::Integer(next_pull_gen));
                set_last_change_id_stmt
                    .bind_at(2.try_into().unwrap(), Value::Integer(next_change_id));
                set_last_change_id_stmt
                    .bind_at(3.try_into().unwrap(), Value::Text(Text::new(client_id)));
                exec_stmt(coro, &mut set_last_change_id_stmt).await?;
            }
            _ => {}
        }
        // every operation, including the commit itself, is forwarded to the replay session
        session.replay(coro, operation).await?;
    }

    Ok(())
}
mut target_sync_watermark = target_session.frames_count()?; + tracing::debug!( + "transfer_physical_changes: start={}, end={}", + source_wal_match_watermark + 1, + source_frames_count + ); + for source_frame_no in source_wal_match_watermark + 1..=source_frames_count { + let frame_info = source_conn.wal_get_frame(source_frame_no, &mut frame)?; + tracing::trace!("append page {} to target DB", frame_info.page_no); + target_session.append_page(frame_info.page_no, &frame[WAL_FRAME_HEADER..])?; + if source_frame_no == source_sync_watermark { + target_sync_watermark = target_session.frames_count()? + 1; // +1 because page will be actually commited on next iteration + tracing::debug!("set target_sync_watermark to {}", target_sync_watermark); + } + last_frame_info = Some(frame_info); + } + let db_size = last_frame_info.unwrap().db_size; + tracing::trace!("commit WAL session to target with db_size={db_size}"); + target_session.commit(db_size)?; + assert!(target_sync_watermark != 0); + target_sync_watermark + }; + Ok(target_sync_watermark) +} + +pub async fn reset_wal_file( + coro: &Coro, + wal: Arc, + frames_count: u64, +) -> Result<()> { + let wal_size = if frames_count == 0 { + // let's truncate WAL file completely in order for this operation to safely execute on empty WAL in case of initial bootstrap phase + 0 + } else { + WAL_HEADER + WAL_FRAME_SIZE * (frames_count as usize) + }; + tracing::debug!("reset db wal to the size of {} frames", frames_count); + let c = Completion::new_trunc(move |rc| { + assert!(rc as usize == 0); + }); + let c = wal.truncate(wal_size, c)?; + while !c.is_completed() { + coro.yield_(ProtocolCommand::IO).await?; + } + Ok(()) +} + +async fn wal_pull_http( + coro: &Coro, + client: &C, + generation: u64, + start_frame: u64, + end_frame: u64, +) -> Result> { + let completion = client.http( + http::Method::GET, + format!("/sync/{generation}/{start_frame}/{end_frame}"), + None, + )?; + let status = wait_status(coro, &completion).await?; + if status == 
http::StatusCode::BAD_REQUEST { + let status_body = wait_full_body(coro, &completion).await?; + let status: DbSyncStatus = serde_json::from_slice(&status_body)?; + if status.status == "checkpoint_needed" { + return Ok(WalHttpPullResult::NeedCheckpoint(status)); + } else { + let error = format!("wal_pull: unexpected sync status: {status:?}"); + return Err(Error::DatabaseSyncEngineError(error)); + } + } + if status != http::StatusCode::OK { + let error = format!("wal_pull: unexpected status code: {status}"); + return Err(Error::DatabaseSyncEngineError(error)); + } + Ok(WalHttpPullResult::Frames(completion)) +} + +async fn wal_push_http( + coro: &Coro, + client: &C, + baton: Option, + generation: u64, + start_frame: u64, + end_frame: u64, + frames: Vec, +) -> Result { + let baton = baton + .map(|baton| format!("/{baton}")) + .unwrap_or("".to_string()); + let completion = client.http( + http::Method::POST, + format!("/sync/{generation}/{start_frame}/{end_frame}{baton}"), + Some(frames), + )?; + let status = wait_status(coro, &completion).await?; + let status_body = wait_full_body(coro, &completion).await?; + if status != http::StatusCode::OK { + let error = std::str::from_utf8(&status_body).ok().unwrap_or(""); + return Err(Error::DatabaseSyncEngineError(format!( + "wal_push go unexpected status: {status} (error={error})" + ))); + } + Ok(serde_json::from_slice(&status_body)?) +} + +async fn db_info_http(coro: &Coro, client: &C) -> Result { + let completion = client.http(http::Method::GET, "/info".to_string(), None)?; + let status = wait_status(coro, &completion).await?; + let status_body = wait_full_body(coro, &completion).await?; + if status != http::StatusCode::OK { + return Err(Error::DatabaseSyncEngineError(format!( + "db_info go unexpected status: {status}" + ))); + } + Ok(serde_json::from_slice(&status_body)?) 
+} + +async fn db_bootstrap_http( + coro: &Coro, + client: &C, + generation: u64, +) -> Result { + let completion = client.http(http::Method::GET, format!("/export/{generation}"), None)?; + let status = wait_status(coro, &completion).await?; + if status != http::StatusCode::OK.as_u16() { + return Err(Error::DatabaseSyncEngineError(format!( + "db_bootstrap go unexpected status: {status}" + ))); + } + Ok(completion) +} + +pub async fn wait_status(coro: &Coro, completion: &impl DataCompletion) -> Result { + while completion.status()?.is_none() { + coro.yield_(ProtocolCommand::IO).await?; + } + Ok(completion.status()?.unwrap()) +} + +pub async fn wait_full_body(coro: &Coro, completion: &impl DataCompletion) -> Result> { + let mut bytes = Vec::new(); + loop { + while let Some(poll) = completion.poll_data()? { + bytes.extend_from_slice(poll.data()); + } + if completion.is_done()? { + break; + } + coro.yield_(ProtocolCommand::IO).await?; + } + Ok(bytes) +} + +#[cfg(test)] +pub mod tests { + use std::sync::Arc; + + use tempfile::NamedTempFile; + use turso_core::Value; + + use crate::{ + database_sync_operations::{transfer_logical_changes, transfer_physical_changes}, + database_tape::{run_stmt, DatabaseTape, DatabaseTapeOpts}, + wal_session::WalSession, + Result, + }; + + #[test] + pub fn test_transfer_logical_changes() { + let temp_file1 = NamedTempFile::new().unwrap(); + let db_path1 = temp_file1.path().to_str().unwrap(); + let temp_file2 = NamedTempFile::new().unwrap(); + let db_path2 = temp_file2.path().to_str().unwrap(); + + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, true).unwrap(); + let db1 = Arc::new(DatabaseTape::new(db1)); + + let db2 = turso_core::Database::open_file(io.clone(), db_path2, false, true).unwrap(); + let db2 = Arc::new(DatabaseTape::new(db2)); + + let mut gen = genawaiter::sync::Gen::new(|coro| async move { + let conn1 = db1.connect(&coro).await?; + 
conn1.execute("CREATE TABLE t(x, y)")?; + conn1.execute("INSERT INTO t VALUES (1, 2), (3, 4), (5, 6)")?; + + let conn2 = db2.connect(&coro).await?; + conn2.execute("CREATE TABLE t(x, y)")?; + + transfer_logical_changes(&coro, &db1, &db2, "id-1", false).await?; + + let mut rows = Vec::new(); + let mut stmt = conn2.prepare("SELECT x, y FROM t").unwrap(); + while let Some(row) = run_stmt(&coro, &mut stmt).await.unwrap() { + rows.push(row.get_values().cloned().collect::>()); + } + assert_eq!( + rows, + vec![ + vec![Value::Integer(1), Value::Integer(2)], + vec![Value::Integer(3), Value::Integer(4)], + vec![Value::Integer(5), Value::Integer(6)], + ] + ); + + conn1.execute("INSERT INTO t VALUES (7, 8)")?; + transfer_logical_changes(&coro, &db1, &db2, "id-1", false).await?; + + let mut rows = Vec::new(); + let mut stmt = conn2.prepare("SELECT x, y FROM t").unwrap(); + while let Some(row) = run_stmt(&coro, &mut stmt).await.unwrap() { + rows.push(row.get_values().cloned().collect::>()); + } + assert_eq!( + rows, + vec![ + vec![Value::Integer(1), Value::Integer(2)], + vec![Value::Integer(3), Value::Integer(4)], + vec![Value::Integer(5), Value::Integer(6)], + vec![Value::Integer(7), Value::Integer(8)], + ] + ); + + Result::Ok(()) + }); + loop { + match gen.resume_with(Ok(())) { + genawaiter::GeneratorState::Yielded(..) 
=> io.run_once().unwrap(), + genawaiter::GeneratorState::Complete(result) => { + result.unwrap(); + break; + } + } + } + } + + #[test] + pub fn test_transfer_physical_changes() { + let temp_file1 = NamedTempFile::new().unwrap(); + let db_path1 = temp_file1.path().to_str().unwrap(); + let temp_file2 = NamedTempFile::new().unwrap(); + let db_path2 = temp_file2.path().to_str().unwrap(); + + let opts = DatabaseTapeOpts { + cdc_mode: Some("off".to_string()), + cdc_table: None, + }; + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, true).unwrap(); + let db1 = Arc::new(DatabaseTape::new_with_opts(db1, opts.clone())); + + let db2 = turso_core::Database::open_file(io.clone(), db_path2, false, true).unwrap(); + let db2 = Arc::new(DatabaseTape::new_with_opts(db2, opts.clone())); + + let mut gen = genawaiter::sync::Gen::new(|coro| async move { + let conn1 = db1.connect(&coro).await?; + conn1.execute("CREATE TABLE t(x, y)")?; + conn1.execute("INSERT INTO t VALUES (1, 2)")?; + let conn1_match_watermark = conn1.wal_frame_count().unwrap(); + conn1.execute("INSERT INTO t VALUES (3, 4)")?; + let conn1_sync_watermark = conn1.wal_frame_count().unwrap(); + conn1.execute("INSERT INTO t VALUES (5, 6)")?; + + let conn2 = db2.connect(&coro).await?; + conn2.execute("CREATE TABLE t(x, y)")?; + conn2.execute("INSERT INTO t VALUES (1, 2)")?; + let conn2_match_watermark = conn2.wal_frame_count().unwrap(); + conn2.execute("INSERT INTO t VALUES (5, 6)")?; + + // db1 WAL frames: [A1 A2] [A3] [A4] (sync_watermark) [A5] + // db2 WAL frames: [B1 B2] [B3] [B4] + + let session = WalSession::new(conn2); + let conn2_sync_watermark = transfer_physical_changes( + &coro, + &db1, + session, + conn1_match_watermark, + conn1_sync_watermark, + conn2_match_watermark, + ) + .await?; + + // db2 WAL frames: [B1 B2] [B3] [B4] [B4^-1] [A4] (sync_watermark) [A5] + assert_eq!(conn2_sync_watermark, 6); + + let conn2 = 
db2.connect(&coro).await.unwrap(); + let mut rows = Vec::new(); + let mut stmt = conn2.prepare("SELECT x, y FROM t").unwrap(); + while let Some(row) = run_stmt(&coro, &mut stmt).await.unwrap() { + rows.push(row.get_values().cloned().collect::>()); + } + assert_eq!( + rows, + vec![ + vec![Value::Integer(1), Value::Integer(2)], + vec![Value::Integer(3), Value::Integer(4)], + vec![Value::Integer(5), Value::Integer(6)], + ] + ); + + conn2.execute("INSERT INTO t VALUES (7, 8)")?; + let mut rows = Vec::new(); + let mut stmt = conn2.prepare("SELECT x, y FROM t").unwrap(); + while let Some(row) = run_stmt(&coro, &mut stmt).await.unwrap() { + rows.push(row.get_values().cloned().collect::>()); + } + assert_eq!( + rows, + vec![ + vec![Value::Integer(1), Value::Integer(2)], + vec![Value::Integer(3), Value::Integer(4)], + vec![Value::Integer(5), Value::Integer(6)], + vec![Value::Integer(7), Value::Integer(8)], + ] + ); + + Result::Ok(()) + }); + loop { + match gen.resume_with(Ok(())) { + genawaiter::GeneratorState::Yielded(..) 
=> io.run_once().unwrap(), + genawaiter::GeneratorState::Complete(result) => { + result.unwrap(); + break; + } + } + } + } +} diff --git a/packages/turso-sync-engine/src/database_tape.rs b/packages/turso-sync-engine/src/database_tape.rs new file mode 100644 index 000000000..a2bbddaca --- /dev/null +++ b/packages/turso-sync-engine/src/database_tape.rs @@ -0,0 +1,944 @@ +use std::{ + collections::{HashMap, VecDeque}, + sync::Arc, +}; + +use turso_core::{types::WalFrameInfo, StepResult}; + +use crate::{ + database_sync_operations::WAL_FRAME_HEADER, + errors::Error, + types::{ + Coro, DatabaseChange, DatabaseTapeOperation, DatabaseTapeRowChange, + DatabaseTapeRowChangeType, ProtocolCommand, + }, + wal_session::WalSession, + Result, +}; + +/// Simple wrapper over [turso::Database] which extends its intereface with few methods +/// to collect changes made to the database and apply/revert arbitrary changes to the database +pub struct DatabaseTape { + inner: Arc, + cdc_table: Arc, + pragma_query: String, +} + +const DEFAULT_CDC_TABLE_NAME: &str = "turso_cdc"; +const DEFAULT_CDC_MODE: &str = "full"; +const DEFAULT_CHANGES_BATCH_SIZE: usize = 100; +const CDC_PRAGMA_NAME: &str = "unstable_capture_data_changes_conn"; + +#[derive(Debug, Clone)] +pub struct DatabaseTapeOpts { + pub cdc_table: Option, + pub cdc_mode: Option, +} + +pub(crate) async fn run_stmt<'a>( + coro: &'_ Coro, + stmt: &'a mut turso_core::Statement, +) -> Result> { + loop { + match stmt.step()? 
{ + StepResult::IO => { + coro.yield_(ProtocolCommand::IO).await?; + } + StepResult::Done => { + return Ok(None); + } + StepResult::Interrupt => { + return Err(Error::DatabaseTapeError( + "statement was interrupted".to_string(), + )) + } + StepResult::Busy => { + return Err(Error::DatabaseTapeError("database is busy".to_string())) + } + StepResult::Row => return Ok(Some(stmt.row().unwrap())), + } + } +} + +pub(crate) async fn exec_stmt(coro: &Coro, stmt: &mut turso_core::Statement) -> Result<()> { + loop { + match stmt.step()? { + StepResult::IO => { + coro.yield_(ProtocolCommand::IO).await?; + } + StepResult::Done => { + return Ok(()); + } + StepResult::Interrupt => { + return Err(Error::DatabaseTapeError( + "statement was interrupted".to_string(), + )) + } + StepResult::Busy => { + return Err(Error::DatabaseTapeError("database is busy".to_string())) + } + StepResult::Row => panic!("statement should not return any rows"), + } + } +} + +impl DatabaseTape { + pub fn new(database: Arc) -> Self { + let opts = DatabaseTapeOpts { + cdc_table: None, + cdc_mode: None, + }; + Self::new_with_opts(database, opts) + } + pub fn new_with_opts(database: Arc, opts: DatabaseTapeOpts) -> Self { + tracing::debug!("create local sync database with options {:?}", opts); + let cdc_table_name = opts.cdc_table.unwrap_or(DEFAULT_CDC_TABLE_NAME.to_string()); + let cdc_mode = opts.cdc_mode.unwrap_or(DEFAULT_CDC_MODE.to_string()); + let pragma_query = format!("PRAGMA {CDC_PRAGMA_NAME}('{cdc_mode},{cdc_table_name}')"); + Self { + inner: database, + cdc_table: Arc::new(cdc_table_name.to_string()), + pragma_query, + } + } + pub(crate) fn connect_untracked(&self) -> Result> { + let connection = self.inner.connect()?; + Ok(connection) + } + pub async fn connect(&self, coro: &Coro) -> Result> { + let connection = self.inner.connect()?; + tracing::debug!("set '{CDC_PRAGMA_NAME}' for new connection"); + let mut stmt = connection.prepare(&self.pragma_query)?; + run_stmt(coro, &mut stmt).await?; + 
/// Methods for editing the WAL through a [WalSession]: pages are appended (or restored to an
/// earlier state) one frame at a time.
///
/// The most recently produced frame is always kept staged in `prepared_frame` and is written
/// only when the next frame is produced (with `db_size = 0`) or when [Self::commit] is called
/// (with the real `db_size`). This way the last frame of the session is the one that carries
/// the db_size passed to commit.
impl DatabaseWalSession {
    /// Creates a session on top of `wal_session`.
    ///
    /// Reads the current WAL frame count and the database page size (`PRAGMA page_size`);
    /// new frames are numbered starting from `frames_count + 1`.
    /// Returns [Error::DatabaseTapeError] if the pragma does not return exactly one row
    /// with a single integer column.
    pub async fn new(coro: &Coro, wal_session: WalSession) -> Result {
        let conn = wal_session.conn();
        let frames_count = conn.wal_frame_count()?;
        let mut page_size_stmt = conn.prepare("PRAGMA page_size")?;
        let Some(row) = run_stmt(coro, &mut page_size_stmt).await? else {
            return Err(Error::DatabaseTapeError(
                "unable to get database page size".to_string(),
            ));
        };
        if row.len() != 1 {
            return Err(Error::DatabaseTapeError(
                "unexpected columns count for PRAGMA page_size query".to_string(),
            ));
        }
        let turso_core::Value::Integer(page_size) = row.get_value(0) else {
            return Err(Error::DatabaseTapeError(
                "unexpected column type for PRAGMA page_size query".to_string(),
            ));
        };
        let page_size = *page_size;
        // step the statement once more: it must report completion, not a second row
        let None = run_stmt(coro, &mut page_size_stmt).await? else {
            return Err(Error::DatabaseTapeError(
                "page size pragma returned multiple rows".to_string(),
            ));
        };

        Ok(Self {
            page_size: page_size as usize,
            next_wal_frame_no: frames_count + 1,
            wal_session,
            prepared_frame: None,
        })
    }

    /// Returns the frame count reported by the underlying connection.
    /// A frame that is still staged in `prepared_frame` has not been inserted yet.
    pub fn frames_count(&self) -> Result {
        Ok(self.wal_session.conn().wal_frame_count()?)
    }

    /// Stages `page` as the next frame for `page_no`.
    ///
    /// The previously staged frame (if any) is written first with `db_size = 0`.
    /// Returns [Error::DatabaseTapeError] if `page` is not exactly one page long.
    pub fn append_page(&mut self, page_no: u32, page: &[u8]) -> Result<()> {
        if page.len() != self.page_size {
            return Err(Error::DatabaseTapeError(format!(
                "page.len() must be equal to page_size: {} != {}",
                page.len(),
                self.page_size
            )));
        }
        self.flush_prepared_frame(0)?;

        // the header part stays zeroed here and is filled in when the frame is flushed
        let mut frame = vec![0u8; WAL_FRAME_HEADER + self.page_size];
        frame[WAL_FRAME_HEADER..].copy_from_slice(page);
        self.prepared_frame = Some((page_no, frame));

        Ok(())
    }

    /// Stages a frame that restores `page_no` to the content read with `frame_watermark`.
    ///
    /// The previously staged frame (if any) is written first with `db_size = 0`.
    /// If the watermark read reports that there is no such page, nothing is staged.
    pub fn rollback_page(&mut self, page_no: u32, frame_watermark: u64) -> Result<()> {
        self.flush_prepared_frame(0)?;

        let conn = self.wal_session.conn();
        let mut frame = vec![0u8; WAL_FRAME_HEADER + self.page_size];
        if conn.try_wal_watermark_read_page(
            page_no,
            &mut frame[WAL_FRAME_HEADER..],
            Some(frame_watermark),
        )? {
            tracing::trace!("rollback page {}", page_no);
            self.prepared_frame = Some((page_no, frame));
        } else {
            tracing::trace!(
                "skip rollback page {} as no page existed with given watermark",
                page_no
            );
        }

        Ok(())
    }

    /// Calls [Self::rollback_page] for every page the connection reports as changed after
    /// `frame_watermark`.
    pub fn rollback_changes_after(&mut self, frame_watermark: u64) -> Result<()> {
        let conn = self.wal_session.conn();
        let pages = conn.wal_changed_pages_after(frame_watermark)?;
        for page_no in pages {
            self.rollback_page(page_no, frame_watermark)?;
        }
        Ok(())
    }

    /// Reads page 1 as of the current frame count and decodes bytes 28..32 as a big-endian u32.
    ///
    /// NOTE(review): presumably this is the database-size-in-pages field of the database
    /// header - confirm against the file format.
    /// Panics if page 1 cannot be read at that watermark.
    pub fn db_size(&self) -> Result {
        let frames_count = self.frames_count()?;
        let conn = self.wal_session.conn();
        let mut page = vec![0u8; self.page_size];
        assert!(conn.try_wal_watermark_read_page(1, &mut page, Some(frames_count))?);
        let db_size = u32::from_be_bytes(page[28..32].try_into().unwrap());
        Ok(db_size)
    }

    /// Writes the staged frame (if any) with the given `db_size` in its header.
    /// Does nothing when no frame is staged.
    pub fn commit(&mut self, db_size: u32) -> Result<()> {
        self.flush_prepared_frame(db_size)
    }

    /// Fills the header of the staged frame with `db_size` and its page number, inserts it
    /// into the WAL at `next_wal_frame_no` and advances the counter.
    /// No-op when nothing is staged.
    fn flush_prepared_frame(&mut self, db_size: u32) -> Result<()> {
        let Some((page_no, mut frame)) = self.prepared_frame.take() else {
            return Ok(());
        };

        let frame_info = WalFrameInfo { db_size, page_no };
        frame_info.put_to_frame_header(&mut frame);

        let frame_no = self.next_wal_frame_no;
        tracing::trace!(
            "flush prepared frame {:?} as frame_no {}",
            frame_info,
            frame_no
        );
        self.wal_session.conn().wal_insert_frame(frame_no, &frame)?;
        self.next_wal_frame_no += 1;

        Ok(())
    }
}
ORDER BY change_id {order} LIMIT {limit}", + ) + } + pub fn first_id(&self) -> i64 { + match self { + DatabaseChangesIteratorMode::Apply => -1, + DatabaseChangesIteratorMode::Revert => i64::MAX, + } + } + pub fn next_id(&self, id: i64) -> i64 { + match self { + DatabaseChangesIteratorMode::Apply => id + 1, + DatabaseChangesIteratorMode::Revert => id - 1, + } + } +} + +#[derive(Debug)] +pub struct DatabaseChangesIteratorOpts { + pub first_change_id: Option, + pub batch_size: usize, + pub mode: DatabaseChangesIteratorMode, + pub ignore_schema_changes: bool, +} + +impl Default for DatabaseChangesIteratorOpts { + fn default() -> Self { + Self { + first_change_id: None, + batch_size: DEFAULT_CHANGES_BATCH_SIZE, + mode: DatabaseChangesIteratorMode::Apply, + ignore_schema_changes: true, + } + } +} + +pub struct DatabaseChangesIterator { + query_stmt: turso_core::Statement, + first_change_id: Option, + batch: VecDeque, + txn_boundary_returned: bool, + mode: DatabaseChangesIteratorMode, + ignore_schema_changes: bool, +} + +impl DatabaseChangesIterator { + pub async fn next(&mut self, coro: &Coro) -> Result> { + if self.batch.is_empty() { + self.refill(coro).await?; + } + // todo(sivukhin): iterator must be more clever about transaction boundaries - but for that we need to extend CDC table + // for now, if iterator reach the end of CDC table - we are sure that this is a transaction boundary + loop { + let next = if let Some(change) = self.batch.pop_front() { + self.txn_boundary_returned = false; + Some(DatabaseTapeOperation::RowChange(change)) + } else if !self.txn_boundary_returned { + self.txn_boundary_returned = true; + Some(DatabaseTapeOperation::Commit) + } else { + None + }; + if let Some(DatabaseTapeOperation::RowChange(change)) = &next { + if self.ignore_schema_changes && change.table_name == "sqlite_schema" { + continue; + } + } + return Ok(next); + } + } + async fn refill(&mut self, coro: &Coro) -> Result<()> { + let change_id_filter = 
self.first_change_id.unwrap_or(self.mode.first_id()); + self.query_stmt.reset(); + self.query_stmt.bind_at( + 1.try_into().unwrap(), + turso_core::Value::Integer(change_id_filter), + ); + + while let Some(row) = run_stmt(coro, &mut self.query_stmt).await? { + let database_change: DatabaseChange = row.try_into()?; + let tape_change = match self.mode { + DatabaseChangesIteratorMode::Apply => database_change.into_apply()?, + DatabaseChangesIteratorMode::Revert => database_change.into_revert()?, + }; + self.batch.push_back(tape_change); + } + let batch_len = self.batch.len(); + if batch_len > 0 { + self.first_change_id = Some(self.mode.next_id(self.batch[batch_len - 1].change_id)); + } + Ok(()) + } +} + +#[derive(Debug)] +pub struct DatabaseReplaySessionOpts { + pub use_implicit_rowid: bool, +} + +struct DeleteCachedStmt { + stmt: turso_core::Statement, + pk_column_indices: Option>, // if None - use rowid instead +} + +pub struct DatabaseReplaySession { + conn: Arc, + cached_delete_stmt: HashMap, + cached_insert_stmt: HashMap<(String, usize), turso_core::Statement>, + in_txn: bool, + opts: DatabaseReplaySessionOpts, +} + +impl DatabaseReplaySession { + pub fn conn(&self) -> Arc { + self.conn.clone() + } + pub async fn replay(&mut self, coro: &Coro, operation: DatabaseTapeOperation) -> Result<()> { + match operation { + DatabaseTapeOperation::Commit => { + tracing::debug!("replay: commit replayed changes after transaction boundary"); + if self.in_txn { + self.conn.execute("COMMIT")?; + self.in_txn = false; + } + } + DatabaseTapeOperation::RowChange(change) => { + if !self.in_txn { + tracing::trace!("replay: start txn for replaying changes"); + self.conn.execute("BEGIN")?; + self.in_txn = true; + } + tracing::trace!("replay: change={:?}", change); + let table_name = &change.table_name; + match change.change { + DatabaseTapeRowChangeType::Delete { before } => { + let before = parse_bin_record(before)?; + self.replay_delete(coro, table_name, change.id, before) + .await? 
+ } + DatabaseTapeRowChangeType::Update { before, after } => { + let before = parse_bin_record(before)?; + self.replay_delete(coro, table_name, change.id, before) + .await?; + let after = parse_bin_record(after)?; + self.replay_insert(coro, table_name, change.id, after) + .await?; + } + DatabaseTapeRowChangeType::Insert { after } => { + let values = parse_bin_record(after)?; + self.replay_insert(coro, table_name, change.id, values) + .await?; + } + } + } + } + Ok(()) + } + async fn replay_delete( + &mut self, + coro: &Coro, + table_name: &str, + id: i64, + mut values: Vec, + ) -> Result<()> { + let cached = self.cached_delete_stmt(coro, table_name).await?; + if let Some(pk_column_indices) = &cached.pk_column_indices { + for (i, pk_column) in pk_column_indices.iter().enumerate() { + let value = std::mem::replace(&mut values[*pk_column], turso_core::Value::Null); + cached.stmt.bind_at((i + 1).try_into().unwrap(), value); + } + } else { + let value = turso_core::Value::Integer(id); + cached.stmt.bind_at(1.try_into().unwrap(), value); + } + exec_stmt(coro, &mut cached.stmt).await?; + Ok(()) + } + async fn replay_insert( + &mut self, + coro: &Coro, + table_name: &str, + id: i64, + values: Vec, + ) -> Result<()> { + let columns = values.len(); + let use_implicit_rowid = self.opts.use_implicit_rowid; + let stmt = self.cached_insert_stmt(coro, table_name, columns).await?; + stmt.reset(); + + for (i, value) in values.into_iter().enumerate() { + stmt.bind_at((i + 1).try_into().unwrap(), value); + } + if use_implicit_rowid { + stmt.bind_at( + (columns + 1).try_into().unwrap(), + turso_core::Value::Integer(id), + ); + } + exec_stmt(coro, stmt).await?; + Ok(()) + } + async fn cached_delete_stmt( + &mut self, + coro: &Coro, + table_name: &str, + ) -> Result<&mut DeleteCachedStmt> { + if !self.cached_delete_stmt.contains_key(table_name) { + tracing::trace!("prepare delete statement for replay: table={}", table_name); + let stmt = self.delete_query(coro, table_name).await?; + 
self.cached_delete_stmt.insert(table_name.to_string(), stmt); + } + tracing::trace!( + "ready to use prepared delete statement for replay: table={}", + table_name + ); + let cached = self.cached_delete_stmt.get_mut(table_name).unwrap(); + cached.stmt.reset(); + Ok(cached) + } + async fn cached_insert_stmt( + &mut self, + coro: &Coro, + table_name: &str, + columns: usize, + ) -> Result<&mut turso_core::Statement> { + let key = (table_name.to_string(), columns); + if !self.cached_insert_stmt.contains_key(&key) { + tracing::trace!( + "prepare insert statement for replay: table={}, columns={}", + table_name, + columns + ); + let stmt = self.insert_query(coro, table_name, columns).await?; + self.cached_insert_stmt.insert(key.clone(), stmt); + } + tracing::trace!( + "ready to use prepared insert statement for replay: table={}, columns={}", + table_name, + columns + ); + let stmt = self.cached_insert_stmt.get_mut(&key).unwrap(); + stmt.reset(); + Ok(stmt) + } + + async fn insert_query( + &self, + coro: &Coro, + table_name: &str, + columns: usize, + ) -> Result { + let query = if !self.opts.use_implicit_rowid { + let placeholders = ["?"].repeat(columns).join(","); + format!("INSERT INTO {table_name} VALUES ({placeholders})") + } else { + let mut table_info_stmt = self.conn.prepare(format!( + "SELECT name FROM pragma_table_info('{table_name}')" + ))?; + let mut column_names = Vec::with_capacity(columns + 1); + while let Some(column) = run_stmt(coro, &mut table_info_stmt).await? { + let turso_core::Value::Text(text) = column.get_value(0) else { + return Err(Error::DatabaseTapeError( + "unexpected column type for pragma_table_info query".to_string(), + )); + }; + column_names.push(text.to_string()); + } + column_names.push("rowid".to_string()); + + let placeholders = ["?"].repeat(columns + 1).join(","); + let column_names = column_names.join(", "); + format!("INSERT INTO {table_name}({column_names}) VALUES ({placeholders})") + }; + Ok(self.conn.prepare(&query)?) 
+ } + + async fn delete_query(&self, coro: &Coro, table_name: &str) -> Result { + let (query, pk_column_indices) = if self.opts.use_implicit_rowid { + (format!("DELETE FROM {table_name} WHERE rowid = ?"), None) + } else { + let mut pk_info_stmt = self.conn.prepare(format!( + "SELECT cid, name FROM pragma_table_info('{table_name}') WHERE pk = 1" + ))?; + let mut pk_predicates = Vec::with_capacity(1); + let mut pk_column_indices = Vec::with_capacity(1); + while let Some(column) = run_stmt(coro, &mut pk_info_stmt).await? { + let turso_core::Value::Integer(column_id) = column.get_value(0) else { + return Err(Error::DatabaseTapeError( + "unexpected column type for pragma_table_info query".to_string(), + )); + }; + let turso_core::Value::Text(name) = column.get_value(1) else { + return Err(Error::DatabaseTapeError( + "unexpected column type for pragma_table_info query".to_string(), + )); + }; + pk_predicates.push(format!("{name} = ?")); + pk_column_indices.push(*column_id as usize); + } + + if pk_column_indices.is_empty() { + (format!("DELETE FROM {table_name} WHERE rowid = ?"), None) + } else { + let pk_predicates = pk_predicates.join(" AND "); + let query = format!("DELETE FROM {table_name} WHERE {pk_predicates}"); + (query, Some(pk_column_indices)) + } + }; + let use_implicit_rowid = self.opts.use_implicit_rowid; + tracing::trace!("delete_query: table_name={table_name}, query={query}, use_implicit_rowid={use_implicit_rowid}"); + let stmt = self.conn.prepare(&query)?; + Ok(DeleteCachedStmt { + stmt, + pk_column_indices, + }) + } +} + +fn parse_bin_record(bin_record: Vec) -> Result> { + let record = turso_core::types::ImmutableRecord::from_bin_record(bin_record); + let mut cursor = turso_core::types::RecordCursor::new(); + let columns = cursor.count(&record); + let mut values = Vec::with_capacity(columns); + for i in 0..columns { + let value = cursor.get_value(&record, i)?; + values.push(value.to_owned()); + } + Ok(values) +} + +#[cfg(test)] +mod tests { + use 
std::sync::Arc; + + use tempfile::NamedTempFile; + + use crate::{ + database_tape::{run_stmt, DatabaseReplaySessionOpts, DatabaseTape}, + types::{DatabaseTapeOperation, DatabaseTapeRowChange, DatabaseTapeRowChangeType}, + }; + + #[test] + pub fn test_database_tape_connect() { + let temp_file1 = NamedTempFile::new().unwrap(); + let db_path1 = temp_file1.path().to_str().unwrap(); + + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, false).unwrap(); + let db1 = Arc::new(DatabaseTape::new(db1)); + let mut gen = genawaiter::sync::Gen::new({ + let db1 = db1.clone(); + |coro| async move { + let conn = db1.connect(&coro).await.unwrap(); + let mut stmt = conn.prepare("SELECT * FROM turso_cdc").unwrap(); + let mut rows = Vec::new(); + while let Some(row) = run_stmt(&coro, &mut stmt).await.unwrap() { + rows.push(row.get_values().cloned().collect::>()); + } + rows + } + }); + let rows = loop { + match gen.resume_with(Ok(())) { + genawaiter::GeneratorState::Yielded(..) 
=> io.run_once().unwrap(), + genawaiter::GeneratorState::Complete(result) => break result, + } + }; + assert_eq!(rows, vec![] as Vec>); + } + + #[test] + pub fn test_database_tape_iterate_changes() { + let temp_file1 = NamedTempFile::new().unwrap(); + let db_path1 = temp_file1.path().to_str().unwrap(); + + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, false).unwrap(); + let db1 = Arc::new(DatabaseTape::new(db1)); + + let mut gen = genawaiter::sync::Gen::new({ + let db1 = db1.clone(); + |coro| async move { + let conn = db1.connect(&coro).await.unwrap(); + conn.execute("CREATE TABLE t(x)").unwrap(); + conn.execute("INSERT INTO t VALUES (1), (2), (3)").unwrap(); + let opts = Default::default(); + let mut iterator = db1.iterate_changes(opts).unwrap(); + let mut changes = Vec::new(); + while let Some(change) = iterator.next(&coro).await.unwrap() { + changes.push(change); + } + changes + } + }); + let changes = loop { + match gen.resume_with(Ok(())) { + genawaiter::GeneratorState::Yielded(..) => io.run_once().unwrap(), + genawaiter::GeneratorState::Complete(result) => break result, + } + }; + tracing::info!("changes: {:?}", changes); + assert_eq!(changes.len(), 4); + assert!(matches!( + changes[0], + DatabaseTapeOperation::RowChange(DatabaseTapeRowChange { + change_id: 2, + id: 1, + ref table_name, + change: DatabaseTapeRowChangeType::Insert { .. }, + .. + }) if table_name == "t" + )); + assert!(matches!( + changes[1], + DatabaseTapeOperation::RowChange(DatabaseTapeRowChange { + change_id: 3, + id: 2, + ref table_name, + change: DatabaseTapeRowChangeType::Insert { .. }, + .. + }) if table_name == "t" + )); + assert!(matches!( + changes[2], + DatabaseTapeOperation::RowChange(DatabaseTapeRowChange { + change_id: 4, + id: 3, + ref table_name, + change: DatabaseTapeRowChangeType::Insert { .. }, + .. 
+ }) if table_name == "t" + )); + assert!(matches!(changes[3], DatabaseTapeOperation::Commit)); + } + + #[test] + pub fn test_database_tape_replay_changes_preserve_rowid() { + let temp_file1 = NamedTempFile::new().unwrap(); + let db_path1 = temp_file1.path().to_str().unwrap(); + let temp_file2 = NamedTempFile::new().unwrap(); + let db_path2 = temp_file2.path().to_str().unwrap(); + + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, false).unwrap(); + let db1 = Arc::new(DatabaseTape::new(db1)); + + let db2 = turso_core::Database::open_file(io.clone(), db_path2, false, false).unwrap(); + let db2 = Arc::new(DatabaseTape::new(db2)); + + let mut gen = genawaiter::sync::Gen::new({ + let db1 = db1.clone(); + let db2 = db2.clone(); + |coro| async move { + let conn1 = db1.connect(&coro).await.unwrap(); + conn1.execute("CREATE TABLE t(x)").unwrap(); + conn1 + .execute("INSERT INTO t(rowid, x) VALUES (10, 1), (20, 2)") + .unwrap(); + let conn2 = db2.connect(&coro).await.unwrap(); + conn2.execute("CREATE TABLE t(x)").unwrap(); + conn2 + .execute("INSERT INTO t(rowid, x) VALUES (1, -1), (2, -2)") + .unwrap(); + + { + let opts = DatabaseReplaySessionOpts { + use_implicit_rowid: true, + }; + let mut session = db2.start_replay_session(&coro, opts).await.unwrap(); + let opts = Default::default(); + let mut iterator = db1.iterate_changes(opts).unwrap(); + while let Some(operation) = iterator.next(&coro).await.unwrap() { + session.replay(&coro, operation).await.unwrap(); + } + } + let mut stmt = conn2.prepare("SELECT rowid, x FROM t").unwrap(); + let mut rows = Vec::new(); + while let Some(row) = run_stmt(&coro, &mut stmt).await.unwrap() { + rows.push(row.get_values().cloned().collect::>()); + } + rows + } + }); + let rows = loop { + match gen.resume_with(Ok(())) { + genawaiter::GeneratorState::Yielded(..) 
=> io.run_once().unwrap(), + genawaiter::GeneratorState::Complete(rows) => break rows, + } + }; + tracing::info!("rows: {:?}", rows); + assert_eq!( + rows, + vec![ + vec![ + turso_core::Value::Integer(1), + turso_core::Value::Integer(-1) + ], + vec![ + turso_core::Value::Integer(2), + turso_core::Value::Integer(-2) + ], + vec![ + turso_core::Value::Integer(10), + turso_core::Value::Integer(1) + ], + vec![ + turso_core::Value::Integer(20), + turso_core::Value::Integer(2) + ] + ] + ); + } + + #[test] + pub fn test_database_tape_replay_changes_do_not_preserve_rowid() { + let temp_file1 = NamedTempFile::new().unwrap(); + let db_path1 = temp_file1.path().to_str().unwrap(); + let temp_file2 = NamedTempFile::new().unwrap(); + let db_path2 = temp_file2.path().to_str().unwrap(); + + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, false).unwrap(); + let db1 = Arc::new(DatabaseTape::new(db1)); + + let db2 = turso_core::Database::open_file(io.clone(), db_path2, false, false).unwrap(); + let db2 = Arc::new(DatabaseTape::new(db2)); + + let mut gen = genawaiter::sync::Gen::new({ + let db1 = db1.clone(); + let db2 = db2.clone(); + |coro| async move { + let conn1 = db1.connect(&coro).await.unwrap(); + conn1.execute("CREATE TABLE t(x)").unwrap(); + conn1 + .execute("INSERT INTO t(rowid, x) VALUES (10, 1), (20, 2)") + .unwrap(); + let conn2 = db2.connect(&coro).await.unwrap(); + conn2.execute("CREATE TABLE t(x)").unwrap(); + conn2 + .execute("INSERT INTO t(rowid, x) VALUES (1, -1), (2, -2)") + .unwrap(); + + { + let opts = DatabaseReplaySessionOpts { + use_implicit_rowid: false, + }; + let mut session = db2.start_replay_session(&coro, opts).await.unwrap(); + let opts = Default::default(); + let mut iterator = db1.iterate_changes(opts).unwrap(); + while let Some(operation) = iterator.next(&coro).await.unwrap() { + session.replay(&coro, operation).await.unwrap(); + } + } + let mut stmt = 
conn2.prepare("SELECT rowid, x FROM t").unwrap(); + let mut rows = Vec::new(); + while let Some(row) = run_stmt(&coro, &mut stmt).await.unwrap() { + rows.push(row.get_values().cloned().collect::>()); + } + rows + } + }); + let rows = loop { + match gen.resume_with(Ok(())) { + genawaiter::GeneratorState::Yielded(..) => io.run_once().unwrap(), + genawaiter::GeneratorState::Complete(rows) => break rows, + } + }; + tracing::info!("rows: {:?}", rows); + assert_eq!( + rows, + vec![ + vec![ + turso_core::Value::Integer(1), + turso_core::Value::Integer(-1) + ], + vec![ + turso_core::Value::Integer(2), + turso_core::Value::Integer(-2) + ], + vec![turso_core::Value::Integer(3), turso_core::Value::Integer(1)], + vec![turso_core::Value::Integer(4), turso_core::Value::Integer(2)] + ] + ); + } + + #[test] + pub fn test_database_tape_replay_changes_delete() { + let temp_file1 = NamedTempFile::new().unwrap(); + let db_path1 = temp_file1.path().to_str().unwrap(); + let temp_file2 = NamedTempFile::new().unwrap(); + let db_path2 = temp_file2.path().to_str().unwrap(); + + let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); + let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, true).unwrap(); + let db1 = Arc::new(DatabaseTape::new(db1)); + + let db2 = turso_core::Database::open_file(io.clone(), db_path2, false, true).unwrap(); + let db2 = Arc::new(DatabaseTape::new(db2)); + + let mut gen = genawaiter::sync::Gen::new({ + let db1 = db1.clone(); + let db2 = db2.clone(); + |coro| async move { + let conn1 = db1.connect(&coro).await.unwrap(); + conn1.execute("CREATE TABLE t(x TEXT PRIMARY KEY)").unwrap(); + conn1.execute("INSERT INTO t(x) VALUES ('a')").unwrap(); + conn1.execute("DELETE FROM t").unwrap(); + let conn2 = db2.connect(&coro).await.unwrap(); + conn2.execute("CREATE TABLE t(x TEXT PRIMARY KEY)").unwrap(); + conn2.execute("INSERT INTO t(x) VALUES ('b')").unwrap(); + + { + let opts = DatabaseReplaySessionOpts { + use_implicit_rowid: false, + }; + let 
mut session = db2.start_replay_session(&coro, opts).await.unwrap(); + let opts = Default::default(); + let mut iterator = db1.iterate_changes(opts).unwrap(); + while let Some(operation) = iterator.next(&coro).await.unwrap() { + session.replay(&coro, operation).await.unwrap(); + } + } + let mut stmt = conn2.prepare("SELECT rowid, x FROM t").unwrap(); + let mut rows = Vec::new(); + while let Some(row) = run_stmt(&coro, &mut stmt).await.unwrap() { + rows.push(row.get_values().cloned().collect::>()); + } + rows + } + }); + let rows = loop { + match gen.resume_with(Ok(())) { + genawaiter::GeneratorState::Yielded(..) => io.run_once().unwrap(), + genawaiter::GeneratorState::Complete(rows) => break rows, + } + }; + tracing::info!("rows: {:?}", rows); + assert_eq!( + rows, + vec![vec![ + turso_core::Value::Integer(1), + turso_core::Value::Text(turso_core::types::Text::new("b")) + ]] + ); + } +} diff --git a/packages/turso-sync-engine/src/errors.rs b/packages/turso-sync-engine/src/errors.rs new file mode 100644 index 000000000..abd40a666 --- /dev/null +++ b/packages/turso-sync-engine/src/errors.rs @@ -0,0 +1,20 @@ +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("database error: {0}")] + TursoError(#[from] turso_core::LimboError), + #[error("database tape error: {0}")] + DatabaseTapeError(String), + #[error("deserialization error: {0}")] + JsonDecode(#[from] serde_json::Error), + #[error("database sync engine error: {0}")] + DatabaseSyncEngineError(String), + #[error("database sync engine conflict: {0}")] + DatabaseSyncEngineConflict(String), +} + +#[cfg(test)] +impl From for Error { + fn from(value: turso::Error) -> Self { + Self::TursoError(turso_core::LimboError::InternalError(value.to_string())) + } +} diff --git a/packages/turso-sync-engine/src/io_operations.rs b/packages/turso-sync-engine/src/io_operations.rs new file mode 100644 index 000000000..a5eee38f7 --- /dev/null +++ b/packages/turso-sync-engine/src/io_operations.rs @@ -0,0 +1,61 @@ +use 
std::sync::Arc;

use turso_core::{Completion, LimboError, OpenFlags};

use crate::{
    database_tape::{DatabaseTape, DatabaseTapeOpts},
    types::{Coro, ProtocolCommand},
    Result,
};

/// File and database helpers which the sync engine needs from the IO layer.
pub trait IoOperations {
    /// Opens the database at `path` wrapped into a [DatabaseTape]; `capture` selects the CDC
    /// mode (`"full"` when set, `"off"` otherwise).
    fn open_tape(&self, path: &str, capture: bool) -> Result<DatabaseTape>;
    /// Opens an existing file; returns `Ok(None)` if the file does not exist.
    fn try_open(&self, path: &str) -> Result<Option<Arc<dyn turso_core::File>>>;
    /// Opens the file with [OpenFlags::Create].
    fn create(&self, path: &str) -> Result<Arc<dyn turso_core::File>>;
    /// Truncates `file` to `len` bytes, yielding [ProtocolCommand::IO] until the operation
    /// completes.
    fn truncate(
        &self,
        coro: &Coro,
        file: Arc<dyn turso_core::File>,
        len: usize,
    ) -> impl std::future::Future<Output = Result<()>>;
}

impl IoOperations for Arc<dyn turso_core::IO> {
    fn open_tape(&self, path: &str, capture: bool) -> Result<DatabaseTape> {
        let io = self.clone();
        // propagate open failures to the caller instead of panicking on them
        let clean = turso_core::Database::open_file(io, path, false, true)?;
        let opts = DatabaseTapeOpts {
            cdc_table: None,
            cdc_mode: Some(if capture { "full" } else { "off" }.to_string()),
        };
        tracing::debug!("initialize database tape connection: path={}", path);
        Ok(DatabaseTape::new_with_opts(clean, opts))
    }

    fn try_open(&self, path: &str) -> Result<Option<Arc<dyn turso_core::File>>> {
        match self.open_file(path, OpenFlags::None, false) {
            Ok(file) => Ok(Some(file)),
            // a missing file is an expected outcome here, not an error
            Err(LimboError::IOError(err)) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
            Err(err) => Err(err.into()),
        }
    }

    fn create(&self, path: &str) -> Result<Arc<dyn turso_core::File>> {
        Ok(self.open_file(path, OpenFlags::Create, false)?)
    }

    async fn truncate(
        &self,
        coro: &Coro,
        file: Arc<dyn turso_core::File>,
        len: usize,
    ) -> Result<()> {
        let c = Completion::new_trunc(move |rc| tracing::debug!("file truncated: rc={}", rc));
        let c = file.truncate(len, c)?;
        // hand control back to the driver so it can run IO until the completion fires
        while !c.is_completed() {
            coro.yield_(ProtocolCommand::IO).await?;
        }
        Ok(())
    }
}
diff --git a/packages/turso-sync-engine/src/lib.rs b/packages/turso-sync-engine/src/lib.rs new file mode 100644 index 000000000..4b0f723c1 --- /dev/null +++ b/packages/turso-sync-engine/src/lib.rs @@ -0,0 +1,173 @@ +pub mod database_sync_engine; +pub mod database_sync_operations; +pub mod database_tape; +pub mod errors; +pub mod io_operations; +pub mod 
protocol_io; +pub mod types; +pub mod wal_session; + +#[cfg(test)] +pub mod test_context; +#[cfg(test)] +pub mod test_protocol_io; +#[cfg(test)] +pub mod test_sync_server; + +pub type Result = std::result::Result; + +#[cfg(test)] +mod tests { + use std::{path::PathBuf, sync::Arc}; + + use tokio::{select, sync::Mutex}; + use tracing_subscriber::EnvFilter; + use turso_core::IO; + + use crate::{ + database_sync_engine::{DatabaseSyncEngine, DatabaseSyncEngineOpts}, + errors::Error, + test_context::TestContext, + test_protocol_io::TestProtocolIo, + types::{Coro, ProtocolCommand}, + Result, + }; + + #[ctor::ctor] + fn init() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_ansi(false) + .init(); + } + + pub fn seed_u64() -> u64 { + seed().parse().unwrap_or(0) + } + + pub fn seed() -> String { + std::env::var("SEED").unwrap_or("0".to_string()) + } + + pub fn deterministic_runtime_from_seed>( + seed: &[u8], + f: impl Fn() -> F, + ) { + let seed = tokio::runtime::RngSeed::from_bytes(seed); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_time() + .start_paused(true) + .rng_seed(seed) + .build_local(Default::default()) + .unwrap(); + runtime.block_on(f()); + } + + pub fn deterministic_runtime>(f: impl Fn() -> F) { + let seed = seed(); + deterministic_runtime_from_seed(seed.as_bytes(), f); + } + + pub struct TestRunner { + pub ctx: Arc, + pub io: Arc, + pub sync_server: TestProtocolIo, + db: Option>>>, + } + + impl TestRunner { + pub fn new(ctx: Arc, io: Arc, sync_server: TestProtocolIo) -> Self { + Self { + ctx, + io, + sync_server, + db: None, + } + } + pub async fn init( + &mut self, + local_path: PathBuf, + opts: DatabaseSyncEngineOpts, + ) -> Result<()> { + let io = self.io.clone(); + let server = self.sync_server.clone(); + let db = self + .run(genawaiter::sync::Gen::new(|coro| async move { + DatabaseSyncEngine::new(&coro, io, Arc::new(server), &local_path, opts).await + })) + .await + .unwrap(); + self.db = 
Some(Arc::new(Mutex::new(db))); + Ok(()) + } + pub async fn connect(&self) -> Result { + self.run_db_fn(self.db.as_ref().unwrap(), async move |coro, db| { + Ok(turso::Connection::create(db.connect(coro).await?)) + }) + .await + } + pub async fn pull(&self) -> Result<()> { + self.run_db_fn(self.db.as_ref().unwrap(), async move |coro, db| { + db.pull(coro).await + }) + .await + } + pub async fn push(&self) -> Result<()> { + self.run_db_fn(self.db.as_ref().unwrap(), async move |coro, db| { + db.push(coro).await + }) + .await + } + pub async fn sync(&self) -> Result<()> { + self.run_db_fn(self.db.as_ref().unwrap(), async move |coro, db| { + db.sync(coro).await + }) + .await + } + pub async fn run_db_fn( + &self, + db: &Arc>>, + f: impl AsyncFn(&Coro, &mut DatabaseSyncEngine) -> Result, + ) -> Result { + let g = genawaiter::sync::Gen::new({ + let db = db.clone(); + |coro| async move { + let mut db = db.lock().await; + f(&coro, &mut db).await + } + }); + self.run(g).await + } + pub async fn run>>( + &self, + mut g: genawaiter::sync::Gen, F>, + ) -> Result { + let mut response = Ok(()); + loop { + // we must drive internal tokio clocks on every iteration - otherwise one TestRunner without work can block everything + // if other TestRunner sleeping - as time will "freeze" in this case + self.ctx.random_sleep().await; + + match g.resume_with(response) { + genawaiter::GeneratorState::Complete(result) => return result, + genawaiter::GeneratorState::Yielded(ProtocolCommand::IO) => { + let drained = { + let mut requests = self.sync_server.requests.lock().unwrap(); + requests.drain(..).collect::>() + }; + for mut request in drained { + select! 
{ + value = &mut request => { value.unwrap(); }, + _ = self.ctx.random_sleep() => { self.sync_server.requests.lock().unwrap().push(request); } + }; + } + response = + self.io.run_once().map(|_| ()).map_err(|e| { + Error::DatabaseSyncEngineError(format!("io error: {e}")) + }); + } + } + } + } + } +} diff --git a/packages/turso-sync-engine/src/protocol_io.rs b/packages/turso-sync-engine/src/protocol_io.rs new file mode 100644 index 000000000..f1f599aa7 --- /dev/null +++ b/packages/turso-sync-engine/src/protocol_io.rs @@ -0,0 +1,24 @@ +use crate::Result; + +pub trait DataPollResult { + fn data(&self) -> &[u8]; +} + +pub trait DataCompletion { + type HttpPollResult: DataPollResult; + fn status(&self) -> Result>; + fn poll_data(&self) -> Result>; + fn is_done(&self) -> Result; +} + +pub trait ProtocolIO { + type DataCompletion: DataCompletion; + fn full_read(&self, path: &str) -> Result; + fn full_write(&self, path: &str, content: Vec) -> Result; + fn http( + &self, + method: http::Method, + path: String, + body: Option>, + ) -> Result; +} diff --git a/packages/turso-sync-engine/src/test_context.rs b/packages/turso-sync-engine/src/test_context.rs new file mode 100644 index 000000000..3691550da --- /dev/null +++ b/packages/turso-sync-engine/src/test_context.rs @@ -0,0 +1,147 @@ +use std::{ + collections::{HashMap, HashSet}, + future::Future, + pin::Pin, + sync::Arc, +}; + +use rand::{RngCore, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use tokio::sync::Mutex; + +use crate::{errors::Error, Result}; + +type PinnedFuture = Pin + Send>>; + +pub struct FaultInjectionPlan { + pub is_fault: Box PinnedFuture + Send + Sync>, +} + +pub enum FaultInjectionStrategy { + Disabled, + Record, + Enabled { plan: FaultInjectionPlan }, +} + +impl std::fmt::Debug for FaultInjectionStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Disabled => write!(f, "Disabled"), + Self::Record => write!(f, "Record"), + Self::Enabled { .. 
} => write!(f, "Enabled"), + } + } +} + +pub struct TestContext { + fault_injection: Mutex, + faulty_call: Mutex>, + rng: Mutex, +} + +pub struct FaultSession { + ctx: Arc, + recording: bool, + plans: Option>, +} + +impl FaultSession { + pub async fn next(&mut self) -> Option { + if !self.recording { + self.recording = true; + return Some(FaultInjectionStrategy::Record); + } + if self.plans.is_none() { + self.plans = Some(self.ctx.enumerate_simple_plans().await); + } + + let plans = self.plans.as_mut().unwrap(); + if plans.is_empty() { + return None; + } + + let plan = plans.pop().unwrap(); + Some(FaultInjectionStrategy::Enabled { plan }) + } +} + +impl TestContext { + pub fn new(seed: u64) -> Self { + Self { + rng: Mutex::new(ChaCha8Rng::seed_from_u64(seed)), + fault_injection: Mutex::new(FaultInjectionStrategy::Disabled), + faulty_call: Mutex::new(HashSet::new()), + } + } + pub async fn random_sleep(&self) { + let delay = self.rng.lock().await.next_u64() % 1000; + tokio::time::sleep(std::time::Duration::from_millis(delay)).await + } + + pub async fn rng(&self) -> tokio::sync::MutexGuard { + self.rng.lock().await + } + pub fn fault_session(self: &Arc) -> FaultSession { + FaultSession { + ctx: self.clone(), + recording: false, + plans: None, + } + } + pub async fn switch_mode(&self, updated: FaultInjectionStrategy) { + let mut mode = self.fault_injection.lock().await; + tracing::info!("switch fault injection mode: {:?}", updated); + *mode = updated; + } + pub async fn enumerate_simple_plans(&self) -> Vec { + let mut plans = vec![]; + for call in self.faulty_call.lock().await.iter() { + let mut fault_counts = HashMap::new(); + fault_counts.insert(call.clone(), 1); + + let count = Arc::new(Mutex::new(1)); + let call = call.clone(); + plans.push(FaultInjectionPlan { + is_fault: Box::new(move |name, bt| { + let call = call.clone(); + let count = count.clone(); + Box::pin(async move { + if (name, bt) != call { + return false; + } + let mut count = count.lock().await; + 
*count -= 1; + *count >= 0 + }) + }), + }) + } + plans + } + pub async fn faulty_call(&self, name: &str) -> Result<()> { + tracing::trace!("faulty_call: {}", name); + + // sleep here in order for scheduler to interleave different executions + self.random_sleep().await; + + if let FaultInjectionStrategy::Disabled = &*self.fault_injection.lock().await { + return Ok(()); + } + let bt = std::backtrace::Backtrace::force_capture().to_string(); + match &mut *self.fault_injection.lock().await { + FaultInjectionStrategy::Record => { + let mut call_sites = self.faulty_call.lock().await; + call_sites.insert((name.to_string(), bt)); + Ok(()) + } + FaultInjectionStrategy::Enabled { plan } => { + if plan.is_fault.as_ref()(name.to_string(), bt.clone()).await { + Err(Error::DatabaseSyncEngineError("injected fault".to_string())) + } else { + Ok(()) + } + } + _ => unreachable!("Disabled case handled above"), + } + } +} diff --git a/packages/turso-sync-engine/src/test_empty.db b/packages/turso-sync-engine/src/test_empty.db new file mode 100644 index 000000000..0a06b0094 Binary files /dev/null and b/packages/turso-sync-engine/src/test_empty.db differ diff --git a/packages/turso-sync-engine/src/test_protocol_io.rs b/packages/turso-sync-engine/src/test_protocol_io.rs new file mode 100644 index 000000000..eaef3b6bb --- /dev/null +++ b/packages/turso-sync-engine/src/test_protocol_io.rs @@ -0,0 +1,227 @@ +use std::{ + collections::{HashMap, VecDeque}, + path::Path, + pin::Pin, + sync::Arc, +}; + +use tokio::{sync::Mutex, task::JoinHandle}; + +use crate::{ + errors::Error, + protocol_io::{DataCompletion, DataPollResult, ProtocolIO}, + test_context::TestContext, + test_sync_server::TestSyncServer, + Result, +}; + +#[derive(Clone)] +pub struct TestProtocolIo { + #[allow(clippy::type_complexity)] + pub requests: Arc>>>>>, + pub server: TestSyncServer, + ctx: Arc, + files: Arc>>>, +} + +pub struct TestDataPollResult(Vec); + +impl DataPollResult for TestDataPollResult { + fn data(&self) -> &[u8] 
{ + &self.0 + } +} + +#[derive(Clone)] +pub struct TestDataCompletion { + status: Arc>>, + chunks: Arc>>>, + done: Arc>, + poisoned: Arc>>, +} + +impl Default for TestDataCompletion { + fn default() -> Self { + Self::new() + } +} + +impl TestDataCompletion { + pub fn new() -> Self { + Self { + status: Arc::new(std::sync::Mutex::new(None)), + chunks: Arc::new(std::sync::Mutex::new(VecDeque::new())), + done: Arc::new(std::sync::Mutex::new(false)), + poisoned: Arc::new(std::sync::Mutex::new(None)), + } + } + pub fn set_status(&self, status: u16) { + *self.status.lock().unwrap() = Some(status); + } + + pub fn push_data(&self, data: Vec) { + let mut chunks = self.chunks.lock().unwrap(); + chunks.push_back(data); + } + + pub fn set_done(&self) { + *self.done.lock().unwrap() = true; + } + + pub fn poison(&self, err: &str) { + *self.poisoned.lock().unwrap() = Some(err.to_string()); + } +} + +impl DataCompletion for TestDataCompletion { + type HttpPollResult = TestDataPollResult; + + fn status(&self) -> Result> { + let poison = self.poisoned.lock().unwrap(); + if poison.is_some() { + return Err(Error::DatabaseSyncEngineError(format!( + "status error: {poison:?}" + ))); + } + Ok(*self.status.lock().unwrap()) + } + + fn poll_data(&self) -> Result> { + let poison = self.poisoned.lock().unwrap(); + if poison.is_some() { + return Err(Error::DatabaseSyncEngineError(format!( + "poll_data error: {poison:?}" + ))); + } + let mut chunks = self.chunks.lock().unwrap(); + Ok(chunks.pop_front().map(TestDataPollResult)) + } + + fn is_done(&self) -> Result { + let poison = self.poisoned.lock().unwrap(); + if poison.is_some() { + return Err(Error::DatabaseSyncEngineError(format!( + "is_done error: {poison:?}" + ))); + } + Ok(*self.done.lock().unwrap()) + } +} + +impl TestProtocolIo { + pub async fn new(ctx: Arc, path: &Path) -> Result { + Ok(Self { + ctx: ctx.clone(), + requests: Arc::new(std::sync::Mutex::new(Vec::new())), + server: TestSyncServer::new(ctx, path).await?, + files: 
Arc::new(Mutex::new(HashMap::new())), + }) + } + fn schedule< + Fut: std::future::Future> + Send + 'static, + F: FnOnce(TestSyncServer, TestDataCompletion) -> Fut + Send + 'static, + >( + &self, + completion: TestDataCompletion, + f: F, + ) { + let server = self.server.clone(); + let mut requests = self.requests.lock().unwrap(); + requests.push(Box::pin(tokio::spawn(async move { + if let Err(err) = f(server, completion.clone()).await { + tracing::info!("poison completion: {}", err); + completion.poison(&err.to_string()); + } + }))); + } +} + +impl ProtocolIO for TestProtocolIo { + type DataCompletion = TestDataCompletion; + fn http( + &self, + method: http::Method, + path: String, + data: Option>, + ) -> Result { + let completion = TestDataCompletion::new(); + { + let completion = completion.clone(); + let path = &path[1..].split("/").collect::>(); + match (method.as_str(), path.as_slice()) { + ("GET", ["info"]) => { + self.schedule(completion, |s, c| async move { s.db_info(c).await }); + } + ("GET", ["export", generation]) => { + let generation = generation.parse().unwrap(); + self.schedule(completion, async move |s, c| { + s.db_export(c, generation).await + }); + } + ("GET", ["sync", generation, start, end]) => { + let generation = generation.parse().unwrap(); + let start = start.parse().unwrap(); + let end = end.parse().unwrap(); + self.schedule(completion, async move |s, c| { + s.wal_pull(c, generation, start, end).await + }); + } + ("POST", ["sync", generation, start, end]) => { + let generation = generation.parse().unwrap(); + let start = start.parse().unwrap(); + let end = end.parse().unwrap(); + let data = data.unwrap(); + self.schedule(completion, async move |s, c| { + s.wal_push(c, None, generation, start, end, data).await + }); + } + ("POST", ["sync", generation, start, end, baton]) => { + let baton = baton.to_string(); + let generation = generation.parse().unwrap(); + let start = start.parse().unwrap(); + let end = end.parse().unwrap(); + let data = 
data.unwrap(); + self.schedule(completion, async move |s, c| { + s.wal_push(c, Some(baton), generation, start, end, data) + .await + }); + } + _ => panic!("unexpected sync server request: {method} {path:?}"), + }; + } + Ok(completion) + } + + fn full_read(&self, path: &str) -> Result { + let completion = TestDataCompletion::new(); + let ctx = self.ctx.clone(); + let files = self.files.clone(); + let path = path.to_string(); + self.schedule(completion.clone(), async move |_, c| { + ctx.faulty_call("full_read_start").await?; + let files = files.lock().await; + let result = files.get(&path); + c.push_data(result.cloned().unwrap_or(Vec::new())); + ctx.faulty_call("full_read_end").await?; + c.set_done(); + Ok(()) + }); + Ok(completion) + } + + fn full_write(&self, path: &str, content: Vec) -> Result { + let completion = TestDataCompletion::new(); + let ctx = self.ctx.clone(); + let files = self.files.clone(); + let path = path.to_string(); + self.schedule(completion.clone(), async move |_, c| { + ctx.faulty_call("full_write_start").await?; + let mut files = files.lock().await; + files.insert(path, content); + ctx.faulty_call("full_write_end").await?; + c.set_done(); + Ok(()) + }); + Ok(completion) + } +} diff --git a/packages/turso-sync-engine/src/test_sync_server.rs b/packages/turso-sync-engine/src/test_sync_server.rs new file mode 100644 index 000000000..ccfbc1eb6 --- /dev/null +++ b/packages/turso-sync-engine/src/test_sync_server.rs @@ -0,0 +1,303 @@ +use std::{collections::HashMap, path::Path, sync::Arc}; + +use tokio::sync::Mutex; + +use crate::{ + errors::Error, + test_context::TestContext, + test_protocol_io::TestDataCompletion, + types::{DbSyncInfo, DbSyncStatus}, + Result, +}; + +const PAGE_SIZE: usize = 4096; +const FRAME_SIZE: usize = 24 + PAGE_SIZE; + +struct Generation { + snapshot: Vec, + frames: Vec>, +} + +#[derive(Clone)] +struct SyncSession { + baton: String, + conn: turso::Connection, + in_txn: bool, +} + +struct TestSyncServerState { + generation: 
u64, + generations: HashMap, + sessions: HashMap, +} + +#[derive(Clone)] +pub struct TestSyncServer { + ctx: Arc, + db: turso::Database, + state: Arc>, +} + +impl TestSyncServer { + pub async fn new(ctx: Arc, path: &Path) -> Result { + let mut generations = HashMap::new(); + generations.insert( + 1, + Generation { + snapshot: EMPTY_WAL_MODE_DB.to_vec(), + frames: Vec::new(), + }, + ); + Ok(Self { + ctx, + db: turso::Builder::new_local(path.to_str().unwrap()) + .build() + .await?, + state: Arc::new(Mutex::new(TestSyncServerState { + generation: 1, + generations, + sessions: HashMap::new(), + })), + }) + } + pub async fn db_info(&self, completion: TestDataCompletion) -> Result<()> { + tracing::debug!("db_info"); + self.ctx.faulty_call("db_info_start").await?; + + let state = self.state.lock().await; + let result = DbSyncInfo { + current_generation: state.generation, + }; + + completion.set_status(200); + self.ctx.faulty_call("db_info_status").await?; + + completion.push_data(serde_json::to_vec(&result)?); + self.ctx.faulty_call("db_info_data").await?; + + completion.set_done(); + + Ok(()) + } + + pub async fn db_export( + &self, + completion: TestDataCompletion, + generation_id: u64, + ) -> Result<()> { + tracing::debug!("db_export: {}", generation_id); + self.ctx.faulty_call("db_export_start").await?; + + let state = self.state.lock().await; + let Some(generation) = state.generations.get(&generation_id) else { + return Err(Error::DatabaseSyncEngineError( + "generation not found".to_string(), + )); + }; + completion.set_status(200); + self.ctx.faulty_call("db_export_status").await?; + + completion.push_data(generation.snapshot.clone()); + self.ctx.faulty_call("db_export_push").await?; + + completion.set_done(); + + Ok(()) + } + + pub async fn wal_pull( + &self, + completion: TestDataCompletion, + generation_id: u64, + start_frame: u64, + end_frame: u64, + ) -> Result<()> { + tracing::debug!("wal_pull: {}/{}/{}", generation_id, start_frame, end_frame); + 
self.ctx.faulty_call("wal_pull_start").await?; + + let state = self.state.lock().await; + let Some(generation) = state.generations.get(&generation_id) else { + return Err(Error::DatabaseSyncEngineError( + "generation not found".to_string(), + )); + }; + let mut data = Vec::new(); + for frame_no in start_frame..end_frame { + let frame_idx = frame_no - 1; + let Some(frame) = generation.frames.get(frame_idx as usize) else { + break; + }; + data.extend_from_slice(frame); + } + if data.is_empty() { + let last_generation = state.generations.get(&state.generation).unwrap(); + + let status = DbSyncStatus { + baton: None, + status: "checkpoint_needed".to_string(), + generation: state.generation, + max_frame_no: last_generation.frames.len() as u64, + }; + completion.set_status(400); + self.ctx.faulty_call("wal_pull_400_status").await?; + + completion.push_data(serde_json::to_vec(&status)?); + self.ctx.faulty_call("wal_pull_400_push").await?; + + completion.set_done(); + } else { + completion.set_status(200); + self.ctx.faulty_call("wal_pull_200_status").await?; + + completion.push_data(data); + self.ctx.faulty_call("wal_pull_200_push").await?; + + completion.set_done(); + }; + + Ok(()) + } + + pub async fn wal_push( + &self, + completion: TestDataCompletion, + mut baton: Option, + generation_id: u64, + start_frame: u64, + end_frame: u64, + frames: Vec, + ) -> Result<()> { + tracing::debug!( + "wal_push: {}/{}/{}/{:?}", + generation_id, + start_frame, + end_frame, + baton + ); + self.ctx.faulty_call("wal_push_start").await?; + + let mut session = { + let mut state = self.state.lock().await; + if state.generation != generation_id { + return Err(Error::DatabaseSyncEngineError( + "generation id mismatch".to_string(), + )); + } + let baton_str = baton.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let session = match state.sessions.get(&baton_str) { + Some(session) => session.clone(), + None => { + let session = SyncSession { + baton: baton_str.clone(), + conn: 
self.db.connect()?, + in_txn: false, + }; + state.sessions.insert(baton_str.clone(), session.clone()); + session + } + }; + baton = Some(baton_str.clone()); + session + }; + + let conflict = 'conflict: { + let mut offset = 0; + for frame_no in start_frame..end_frame { + if offset + FRAME_SIZE > frames.len() { + return Err(Error::DatabaseSyncEngineError( + "unexpected length of frames data".to_string(), + )); + } + if !session.in_txn { + session.conn.wal_insert_begin()?; + session.in_txn = true; + } + let frame = &frames[offset..offset + FRAME_SIZE]; + match session.conn.wal_insert_frame(frame_no, frame) { + Ok(info) => { + if info.is_commit_frame() { + if session.in_txn { + session.conn.wal_insert_end()?; + session.in_txn = false; + } + self.sync_frames_from_conn(&session.conn).await?; + } + } + Err(turso::Error::WalOperationError(err)) if err.contains("Conflict") => { + session.conn.wal_insert_end()?; + break 'conflict true; + } + Err(err) => { + session.conn.wal_insert_end()?; + return Err(err.into()); + } + } + offset += FRAME_SIZE; + } + false + }; + let mut state = self.state.lock().await; + state + .sessions + .insert(baton.clone().unwrap(), session.clone()); + let status = DbSyncStatus { + baton: Some(session.baton.clone()), + status: if conflict { "conflict" } else { "ok" }.into(), + generation: state.generation, + max_frame_no: session.conn.wal_frame_count()?, + }; + + let status = serde_json::to_vec(&status)?; + + completion.set_status(200); + self.ctx.faulty_call("wal_push_status").await?; + + completion.push_data(status); + self.ctx.faulty_call("wal_push_push").await?; + + completion.set_done(); + + Ok(()) + } + + pub fn db(&self) -> turso::Database { + self.db.clone() + } + pub async fn execute(&self, sql: &str, params: impl turso::IntoParams) -> Result<()> { + let conn = self.db.connect()?; + conn.execute(sql, params).await?; + self.sync_frames_from_conn(&conn).await?; + Ok(()) + } + async fn sync_frames_from_conn(&self, conn: &turso::Connection) -> 
Result<()> { + let mut state = self.state.lock().await; + let generation = state.generation; + let generation = state.generations.get_mut(&generation).unwrap(); + let last_frame = generation.frames.len() + 1; + let mut frame = [0u8; FRAME_SIZE]; + let wal_frame_count = conn.wal_frame_count()?; + tracing::debug!("conn frames count: {}", wal_frame_count); + for frame_no in last_frame..=wal_frame_count as usize { + conn.wal_get_frame(frame_no as u64, &mut frame)?; + tracing::debug!("push local frame {}", frame_no); + generation.frames.push(frame.to_vec()); + } + Ok(()) + } +} + +// empty DB with single 4096-byte page and WAL mode (PRAGMA journal_mode=WAL) +// see test test_empty_wal_mode_db_content which validates asset content +pub const EMPTY_WAL_MODE_DB: &[u8] = include_bytes!("test_empty.db"); + +pub async fn convert_rows(rows: &mut turso::Rows) -> Result>> { + let mut rows_values = vec![]; + while let Some(row) = rows.next().await? { + let mut row_values = vec![]; + for i in 0..row.column_count() { + row_values.push(row.get_value(i)?); + } + rows_values.push(row_values); + } + Ok(rows_values) +} diff --git a/packages/turso-sync-engine/src/types.rs b/packages/turso-sync-engine/src/types.rs new file mode 100644 index 000000000..af7cad9ce --- /dev/null +++ b/packages/turso-sync-engine/src/types.rs @@ -0,0 +1,269 @@ +use serde::{Deserialize, Serialize}; + +use crate::{errors::Error, Result}; + +pub type Coro = genawaiter::sync::Co>; + +#[derive(Debug, Deserialize, Serialize)] +pub struct DbSyncInfo { + pub current_generation: u64, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct DbSyncStatus { + pub baton: Option, + pub status: String, + pub generation: u64, + pub max_frame_no: u64, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum DatabaseChangeType { + Delete, + Update, + Insert, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct DatabaseMetadata { + /// Unique identifier of the client - generated on sync startup + pub 
client_unique_id: String, + /// Latest generation from remote which was pulled locally to the Synced DB + pub synced_generation: u64, + /// Latest frame number from remote which was pulled locally to the Synced DB + pub synced_frame_no: Option, + /// pair of frame_no for Draft and Synced DB such that content of the database file up to these frames is identical + pub draft_wal_match_watermark: u64, + pub synced_wal_match_watermark: u64, +} + +impl DatabaseMetadata { + pub fn load(data: &[u8]) -> Result { + let meta = serde_json::from_slice::(data)?; + Ok(meta) + } + pub fn dump(&self) -> Result> { + let data = serde_json::to_string(self)?; + Ok(data.into_bytes()) + } +} + +/// [DatabaseChange] struct represents data from CDC table as-i +/// (see `turso_cdc_table_columns` definition in turso-core) +#[derive(Clone)] +pub struct DatabaseChange { + /// Monotonically incrementing change number + pub change_id: i64, + /// Unix timestamp of the change (not guaranteed to be strictly monotonic as host clocks can drift) + pub change_time: u64, + /// Type of the change + pub change_type: DatabaseChangeType, + /// Table of the change + pub table_name: String, + /// Rowid of changed row + pub id: i64, + /// Binary record of the row before the change, if CDC pragma set to either 'before' or 'full' + pub before: Option>, + /// Binary record of the row after the change, if CDC pragma set to either 'after' or 'full' + pub after: Option>, +} + +impl DatabaseChange { + /// Converts [DatabaseChange] into the operation which effect will be the application of the change + pub fn into_apply(self) -> Result { + let tape_change = match self.change_type { + DatabaseChangeType::Delete => DatabaseTapeRowChangeType::Delete { + before: self.before.ok_or_else(|| { + Error::DatabaseTapeError( + "cdc_mode must be set to either 'full' or 'before'".to_string(), + ) + })?, + }, + DatabaseChangeType::Update => DatabaseTapeRowChangeType::Update { + before: self.before.ok_or_else(|| { + 
Error::DatabaseTapeError("cdc_mode must be set to 'full'".to_string()) + })?, + after: self.after.ok_or_else(|| { + Error::DatabaseTapeError("cdc_mode must be set to 'full'".to_string()) + })?, + }, + DatabaseChangeType::Insert => DatabaseTapeRowChangeType::Insert { + after: self.after.ok_or_else(|| { + Error::DatabaseTapeError( + "cdc_mode must be set to either 'full' or 'after'".to_string(), + ) + })?, + }, + }; + Ok(DatabaseTapeRowChange { + change_id: self.change_id, + change_time: self.change_time, + change: tape_change, + table_name: self.table_name, + id: self.id, + }) + } + /// Converts [DatabaseChange] into the operation which effect will be the revert of the change + pub fn into_revert(self) -> Result { + let tape_change = match self.change_type { + DatabaseChangeType::Delete => DatabaseTapeRowChangeType::Insert { + after: self.before.ok_or_else(|| { + Error::DatabaseTapeError( + "cdc_mode must be set to either 'full' or 'before'".to_string(), + ) + })?, + }, + DatabaseChangeType::Update => DatabaseTapeRowChangeType::Update { + before: self.after.ok_or_else(|| { + Error::DatabaseTapeError("cdc_mode must be set to 'full'".to_string()) + })?, + after: self.before.ok_or_else(|| { + Error::DatabaseTapeError( + "cdc_mode must be set to either 'full' or 'before'".to_string(), + ) + })?, + }, + DatabaseChangeType::Insert => DatabaseTapeRowChangeType::Delete { + before: self.after.ok_or_else(|| { + Error::DatabaseTapeError( + "cdc_mode must be set to either 'full' or 'after'".to_string(), + ) + })?, + }, + }; + Ok(DatabaseTapeRowChange { + change_id: self.change_id, + change_time: self.change_time, + change: tape_change, + table_name: self.table_name, + id: self.id, + }) + } +} + +impl std::fmt::Debug for DatabaseChange { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DatabaseChange") + .field("change_id", &self.change_id) + .field("change_time", &self.change_time) + .field("change_type", &self.change_type) + 
.field("table_name", &self.table_name) + .field("id", &self.id) + .field("before.len()", &self.before.as_ref().map(|x| x.len())) + .field("after.len()", &self.after.as_ref().map(|x| x.len())) + .finish() + } +} + +impl TryFrom<&turso_core::Row> for DatabaseChange { + type Error = Error; + + fn try_from(row: &turso_core::Row) -> Result { + let change_id = get_core_value_i64(row, 0)?; + let change_time = get_core_value_i64(row, 1)? as u64; + let change_type = get_core_value_i64(row, 2)?; + let table_name = get_core_value_text(row, 3)?; + let id = get_core_value_i64(row, 4)?; + let before = get_core_value_blob_or_null(row, 5)?; + let after = get_core_value_blob_or_null(row, 6)?; + + let change_type = match change_type { + -1 => DatabaseChangeType::Delete, + 0 => DatabaseChangeType::Update, + 1 => DatabaseChangeType::Insert, + v => { + return Err(Error::DatabaseTapeError(format!( + "unexpected change type: expected -1|0|1, got '{v:?}'" + ))) + } + }; + Ok(Self { + change_id, + change_time, + change_type, + table_name, + id, + before, + after, + }) + } +} + +pub enum DatabaseTapeRowChangeType { + Delete { before: Vec }, + Update { before: Vec, after: Vec }, + Insert { after: Vec }, +} + +/// [DatabaseTapeOperation] extends [DatabaseTapeRowChange] by adding information about transaction boundary +/// +/// This helps [crate::database_tape::DatabaseTapeSession] to properly maintain transaction state and COMMIT or ROLLBACK changes in appropriate time +/// by consuming events from [crate::database_tape::DatabaseChangesIterator] +#[derive(Debug)] +pub enum DatabaseTapeOperation { + RowChange(DatabaseTapeRowChange), + Commit, +} + +/// [DatabaseTapeRowChange] is the specific operation over single row which can be performed on database +#[derive(Debug)] +pub struct DatabaseTapeRowChange { + pub change_id: i64, + pub change_time: u64, + pub change: DatabaseTapeRowChangeType, + pub table_name: String, + pub id: i64, +} + +impl std::fmt::Debug for DatabaseTapeRowChangeType { + fn 
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Delete { before } => f + .debug_struct("Delete") + .field("before.len()", &before.len()) + .finish(), + Self::Update { before, after } => f + .debug_struct("Update") + .field("before.len()", &before.len()) + .field("after.len()", &after.len()) + .finish(), + Self::Insert { after } => f + .debug_struct("Insert") + .field("after.len()", &after.len()) + .finish(), + } + } +} + +fn get_core_value_i64(row: &turso_core::Row, index: usize) -> Result { + match row.get_value(index) { + turso_core::Value::Integer(v) => Ok(*v), + v => Err(Error::DatabaseTapeError(format!( + "column {index} type mismatch: expected integer, got '{v:?}'" + ))), + } +} + +fn get_core_value_text(row: &turso_core::Row, index: usize) -> Result { + match row.get_value(index) { + turso_core::Value::Text(x) => Ok(x.to_string()), + v => Err(Error::DatabaseTapeError(format!( + "column {index} type mismatch: expected string, got '{v:?}'" + ))), + } +} + +fn get_core_value_blob_or_null(row: &turso_core::Row, index: usize) -> Result>> { + match row.get_value(index) { + turso_core::Value::Null => Ok(None), + turso_core::Value::Blob(x) => Ok(Some(x.clone())), + v => Err(Error::DatabaseTapeError(format!( + "column {index} type mismatch: expected blob, got '{v:?}'" + ))), + } +} + +pub enum ProtocolCommand { + // Protocol waits for some IO - caller must spin turso-db IO event loop and also drive ProtocolIO + IO, +} diff --git a/packages/turso-sync-engine/src/wal_session.rs b/packages/turso-sync-engine/src/wal_session.rs new file mode 100644 index 000000000..92e872a92 --- /dev/null +++ b/packages/turso-sync-engine/src/wal_session.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use turso_core::types::WalFrameInfo; + +use crate::Result; + +pub struct WalSession { + conn: Arc, + in_txn: bool, +} + +unsafe impl Send for WalSession {} +unsafe impl Sync for WalSession {} + +impl WalSession { + pub fn new(conn: Arc) -> Self { + Self { 
+ conn, + in_txn: false, + } + } + pub fn conn(&self) -> &Arc { + &self.conn + } + pub fn begin(&mut self) -> Result<()> { + assert!(!self.in_txn); + self.conn.wal_insert_begin()?; + self.in_txn = true; + Ok(()) + } + pub fn insert_at(&mut self, frame_no: u64, frame: &[u8]) -> Result { + assert!(self.in_txn); + let info = self.conn.wal_insert_frame(frame_no, frame)?; + Ok(info) + } + pub fn read_at(&mut self, frame_no: u64, frame: &mut [u8]) -> Result { + assert!(self.in_txn); + let info = self.conn.wal_get_frame(frame_no, frame)?; + Ok(info) + } + pub fn end(&mut self) -> Result<()> { + assert!(self.in_txn); + self.conn.wal_insert_end()?; + self.in_txn = false; + Ok(()) + } + pub fn in_txn(&self) -> bool { + self.in_txn + } +} + +impl Drop for WalSession { + fn drop(&mut self) { + if self.in_txn { + let _ = self + .end() + .inspect_err(|e| tracing::error!("failed to close WAL session: {}", e)); + } + } +}