diff --git a/Cargo.lock b/Cargo.lock
index a1839a233..3ccac9997 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4015,6 +4015,7 @@ dependencies = [
 name = "turso_sync_engine"
 version = "0.1.4-pre.8"
 dependencies = [
+ "base64",
 "bytes",
 "ctor",
 "futures",
diff --git a/packages/turso-sync-engine/Cargo.toml b/packages/turso-sync-engine/Cargo.toml
index d49aa1df9..c2a13c4ee 100644
--- a/packages/turso-sync-engine/Cargo.toml
+++ b/packages/turso-sync-engine/Cargo.toml
@@ -16,6 +16,7 @@ bytes = "1.10.1"
 genawaiter = { version = "0.99.1", default-features = false }
 http = "1.3.1"
 uuid = "1.17.0"
+base64 = "0.22.1"
 
 [dev-dependencies]
 ctor = "0.4.2"
diff --git a/packages/turso-sync-engine/src/database_replay_generator.rs b/packages/turso-sync-engine/src/database_replay_generator.rs
new file mode 100644
index 000000000..cceb6f98f
--- /dev/null
+++ b/packages/turso-sync-engine/src/database_replay_generator.rs
@@ -0,0 +1,338 @@
+use std::sync::Arc;
+
+use crate::{
+    database_tape::{run_stmt_once, DatabaseReplaySessionOpts},
+    errors::Error,
+    types::{Coro, DatabaseChangeType, DatabaseTapeRowChange, DatabaseTapeRowChangeType},
+    Result,
+};
+
+pub struct DatabaseReplayGenerator {
+    pub conn: Arc<turso_core::Connection>,
+    pub opts: DatabaseReplaySessionOpts,
+}
+
+pub struct ReplayInfo {
+    pub change_type: DatabaseChangeType,
+    pub query: String,
+    pub pk_column_indices: Option<Vec<usize>>,
+    pub is_ddl_replay: bool,
+}
+
+const SQLITE_SCHEMA_TABLE: &str = "sqlite_schema";
+impl DatabaseReplayGenerator {
+    pub fn new(conn: Arc<turso_core::Connection>, opts: DatabaseReplaySessionOpts) -> Self {
+        Self { conn, opts }
+    }
+    pub fn replay_values(
+        &self,
+        info: &ReplayInfo,
+        change: DatabaseChangeType,
+        id: i64,
+        mut record: Vec<turso_core::Value>,
+        updates: Option<Vec<turso_core::Value>>,
+    ) -> Vec<turso_core::Value> {
+        if info.is_ddl_replay {
+            return Vec::new();
+        }
+        match change {
+            DatabaseChangeType::Delete => {
+                if self.opts.use_implicit_rowid {
+                    vec![turso_core::Value::Integer(id)]
+                } else {
+                    let mut values = Vec::new();
+                    let pk_column_indices = info.pk_column_indices.as_ref().unwrap();
+                    for pk in pk_column_indices {
+                        let value = std::mem::replace(&mut record[*pk], turso_core::Value::Null);
+                        values.push(value);
+                    }
+                    values
+                }
+            }
+            DatabaseChangeType::Insert => {
+                if self.opts.use_implicit_rowid {
+                    record.push(turso_core::Value::Integer(id));
+                }
+                record
+            }
+            DatabaseChangeType::Update => {
+                let mut updates = updates.unwrap();
+                assert!(updates.len() % 2 == 0);
+                let columns_cnt = updates.len() / 2;
+                let mut values = Vec::with_capacity(columns_cnt + 1);
+                for i in 0..columns_cnt {
+                    let changed = match updates[i] {
+                        turso_core::Value::Integer(x @ (1 | 0)) => x > 0,
+                        _ => panic!(
+                            "unexpected 'changes' binary record first-half component: {:?}",
+                            updates[i]
+                        ),
+                    };
+                    if !changed {
+                        continue;
+                    }
+                    let value =
+                        std::mem::replace(&mut updates[i + columns_cnt], turso_core::Value::Null);
+                    values.push(value);
+                }
+                if let Some(pk_column_indices) = &info.pk_column_indices {
+                    for pk in pk_column_indices {
+                        let value = std::mem::replace(&mut record[*pk], turso_core::Value::Null);
+                        values.push(value);
+                    }
+                } else {
+                    values.push(turso_core::Value::Integer(id));
+                }
+                values
+            }
+        }
+    }
+    pub async fn replay_info(
+        &self,
+        coro: &Coro,
+        change: &DatabaseTapeRowChange,
+    ) -> Result<Vec<ReplayInfo>> {
+        tracing::trace!("replay: change={:?}", change);
+        let table_name = &change.table_name;
+
+        if table_name == SQLITE_SCHEMA_TABLE {
+            // sqlite_schema table: type, name, tbl_name, rootpage, sql
+            match &change.change {
+                DatabaseTapeRowChangeType::Delete { before } => {
+                    assert!(before.len() == 5);
+                    let
Some(turso_core::Value::Text(entity_type)) = before.first() else {
+                        panic!(
+                            "unexpected 'type' column of sqlite_schema table: {:?}",
+                            before.first()
+                        );
+                    };
+                    let Some(turso_core::Value::Text(entity_name)) = before.get(1) else {
+                        panic!(
+                            "unexpected 'name' column of sqlite_schema table: {:?}",
+                            before.get(1)
+                        );
+                    };
+                    let query = format!("DROP {} {}", entity_type.as_str(), entity_name.as_str());
+                    let delete = ReplayInfo {
+                        change_type: DatabaseChangeType::Delete,
+                        query,
+                        pk_column_indices: None,
+                        is_ddl_replay: true,
+                    };
+                    Ok(vec![delete])
+                }
+                DatabaseTapeRowChangeType::Insert { after } => {
+                    assert!(after.len() == 5);
+                    let Some(turso_core::Value::Text(sql)) = after.last() else {
+                        return Err(Error::DatabaseTapeError(format!(
+                            "unexpected 'sql' column of sqlite_schema table: {:?}",
+                            after.last()
+                        )));
+                    };
+                    let insert = ReplayInfo {
+                        change_type: DatabaseChangeType::Insert,
+                        query: sql.as_str().to_string(),
+                        pk_column_indices: None,
+                        is_ddl_replay: true,
+                    };
+                    Ok(vec![insert])
+                }
+                DatabaseTapeRowChangeType::Update { updates, .. } => {
+                    let Some(updates) = updates else {
+                        return Err(Error::DatabaseTapeError(
+                            "'updates' column of CDC table must be populated".to_string(),
+                        ));
+                    };
+                    assert!(updates.len() % 2 == 0);
+                    assert!(updates.len() / 2 == 5);
+                    let turso_core::Value::Text(ddl_stmt) = updates.last().unwrap() else {
+                        panic!(
+                            "unexpected 'sql' column of sqlite_schema table update record: {:?}",
+                            updates.last()
+                        );
+                    };
+                    let update = ReplayInfo {
+                        change_type: DatabaseChangeType::Update,
+                        query: ddl_stmt.as_str().to_string(),
+                        pk_column_indices: None,
+                        is_ddl_replay: true,
+                    };
+                    Ok(vec![update])
+                }
+            }
+        } else {
+            match &change.change {
+                DatabaseTapeRowChangeType::Delete { .. } => {
+                    let delete = self.delete_query(coro, table_name).await?;
+                    Ok(vec![delete])
+                }
+                DatabaseTapeRowChangeType::Update { updates, after, .. } => {
+                    if let Some(updates) = updates {
+                        assert!(updates.len() % 2 == 0);
+                        let columns_cnt = updates.len() / 2;
+                        let mut columns = Vec::with_capacity(columns_cnt);
+                        for value in updates.iter().take(columns_cnt) {
+                            columns.push(match value {
+                                turso_core::Value::Integer(x @ (1 | 0)) => *x > 0,
+                                _ => panic!("unexpected 'changes' binary record first-half component: {value:?}")
+                            });
+                        }
+                        let update = self.update_query(coro, table_name, &columns).await?;
+                        Ok(vec![update])
+                    } else {
+                        let delete = self.delete_query(coro, table_name).await?;
+                        let insert = self.insert_query(coro, table_name, after.len()).await?;
+                        Ok(vec![delete, insert])
+                    }
+                }
+                DatabaseTapeRowChangeType::Insert { after } => {
+                    let insert = self.insert_query(coro, table_name, after.len()).await?;
+                    Ok(vec![insert])
+                }
+            }
+        }
+    }
+    pub(crate) async fn update_query(
+        &self,
+        coro: &Coro,
+        table_name: &str,
+        columns: &[bool],
+    ) -> Result<ReplayInfo> {
+        let mut table_info_stmt = self.conn.prepare(format!(
+            "SELECT cid, name, pk FROM pragma_table_info('{table_name}')"
+        ))?;
+        let mut pk_predicates = Vec::with_capacity(1);
+        let mut pk_column_indices = Vec::with_capacity(1);
+        let mut column_updates = Vec::with_capacity(1);
+        while let Some(column) = run_stmt_once(coro, &mut table_info_stmt).await?
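+        // each pragma_table_info row yields (cid, name, pk): pk == 1 rows feed the
+        // WHERE predicates, rows flagged in `columns` feed the SET clause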
{
+            let turso_core::Value::Integer(column_id) = column.get_value(0) else {
+                return Err(Error::DatabaseTapeError(
+                    "unexpected column type for pragma_table_info query".to_string(),
+                ));
+            };
+            let turso_core::Value::Text(name) = column.get_value(1) else {
+                return Err(Error::DatabaseTapeError(
+                    "unexpected column type for pragma_table_info query".to_string(),
+                ));
+            };
+            let turso_core::Value::Integer(pk) = column.get_value(2) else {
+                return Err(Error::DatabaseTapeError(
+                    "unexpected column type for pragma_table_info query".to_string(),
+                ));
+            };
+            if *pk == 1 {
+                pk_predicates.push(format!("{name} = ?"));
+                pk_column_indices.push(*column_id as usize);
+            }
+            if columns[*column_id as usize] {
+                column_updates.push(format!("{name} = ?"));
+            }
+        }
+
+        let (query, pk_column_indices) = if self.opts.use_implicit_rowid {
+            (
+                format!(
+                    "UPDATE {table_name} SET {} WHERE rowid = ?",
+                    column_updates.join(", ")
+                ),
+                None,
+            )
+        } else {
+            (
+                format!(
+                    "UPDATE {table_name} SET {} WHERE {}",
+                    column_updates.join(", "),
+                    pk_predicates.join(" AND ")
+                ),
+                Some(pk_column_indices),
+            )
+        };
+        Ok(ReplayInfo {
+            change_type: DatabaseChangeType::Update,
+            query,
+            pk_column_indices,
+            is_ddl_replay: false,
+        })
+    }
+    pub(crate) async fn insert_query(
+        &self,
+        coro: &Coro,
+        table_name: &str,
+        columns: usize,
+    ) -> Result<ReplayInfo> {
+        if !self.opts.use_implicit_rowid {
+            let placeholders = ["?"].repeat(columns).join(",");
+            let query = format!("INSERT INTO {table_name} VALUES ({placeholders})");
+            return Ok(ReplayInfo {
+                change_type: DatabaseChangeType::Insert,
+                query,
+                pk_column_indices: None,
+                is_ddl_replay: false,
+            });
+        };
+        let mut table_info_stmt = self.conn.prepare(format!(
+            "SELECT name FROM pragma_table_info('{table_name}')"
+        ))?;
+        let mut column_names = Vec::with_capacity(columns + 1);
+        while let Some(column) = run_stmt_once(coro, &mut table_info_stmt).await? {
+            let turso_core::Value::Text(text) = column.get_value(0) else {
+                return Err(Error::DatabaseTapeError(
+                    "unexpected column type for pragma_table_info query".to_string(),
+                ));
+            };
+            column_names.push(text.to_string());
+        }
+        column_names.push("rowid".to_string());
+
+        let placeholders = ["?"].repeat(columns + 1).join(",");
+        let column_names = column_names.join(", ");
+        let query = format!("INSERT INTO {table_name}({column_names}) VALUES ({placeholders})");
+        Ok(ReplayInfo {
+            change_type: DatabaseChangeType::Insert,
+            query,
+            pk_column_indices: None,
+            is_ddl_replay: false,
+        })
+    }
+    pub(crate) async fn delete_query(&self, coro: &Coro, table_name: &str) -> Result<ReplayInfo> {
+        let (query, pk_column_indices) = if self.opts.use_implicit_rowid {
+            (format!("DELETE FROM {table_name} WHERE rowid = ?"), None)
+        } else {
+            let mut pk_info_stmt = self.conn.prepare(format!(
+                "SELECT cid, name FROM pragma_table_info('{table_name}') WHERE pk = 1"
+            ))?;
+            let mut pk_predicates = Vec::with_capacity(1);
+            let mut pk_column_indices = Vec::with_capacity(1);
+            while let Some(column) = run_stmt_once(coro, &mut pk_info_stmt).await?
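+            // collect the declared primary-key columns for the WHERE clause; when the
+            // table has none, the code below falls back to rowid-based deletion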
{ + let turso_core::Value::Integer(column_id) = column.get_value(0) else { + return Err(Error::DatabaseTapeError( + "unexpected column type for pragma_table_info query".to_string(), + )); + }; + let turso_core::Value::Text(name) = column.get_value(1) else { + return Err(Error::DatabaseTapeError( + "unexpected column type for pragma_table_info query".to_string(), + )); + }; + pk_predicates.push(format!("{name} = ?")); + pk_column_indices.push(*column_id as usize); + } + + if pk_column_indices.is_empty() { + (format!("DELETE FROM {table_name} WHERE rowid = ?"), None) + } else { + let pk_predicates = pk_predicates.join(" AND "); + let query = format!("DELETE FROM {table_name} WHERE {pk_predicates}"); + (query, Some(pk_column_indices)) + } + }; + let use_implicit_rowid = self.opts.use_implicit_rowid; + tracing::trace!("delete_query: table_name={table_name}, query={query}, use_implicit_rowid={use_implicit_rowid}"); + Ok(ReplayInfo { + change_type: DatabaseChangeType::Delete, + query, + pk_column_indices, + is_ddl_replay: false, + }) + } +} diff --git a/packages/turso-sync-engine/src/database_sync_engine.rs b/packages/turso-sync-engine/src/database_sync_engine.rs index 1b8d53161..4c16970c7 100644 --- a/packages/turso-sync-engine/src/database_sync_engine.rs +++ b/packages/turso-sync-engine/src/database_sync_engine.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use crate::{ database_sync_operations::{ - checkpoint_wal_file, connect, connect_untracked, db_bootstrap, reset_wal_file, - transfer_logical_changes, transfer_physical_changes, wait_full_body, wal_pull, wal_push, - WalPullResult, + checkpoint_wal_file, connect, connect_untracked, db_bootstrap, push_logical_changes, + reset_wal_file, transfer_logical_changes, transfer_physical_changes, wait_full_body, + wal_pull, wal_push, WalPullResult, }, database_tape::DatabaseTape, errors::Error, @@ -15,7 +15,7 @@ use crate::{ Result, }; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct DatabaseSyncEngineOpts { pub client_name: String, pub wal_pull_batch_size: u64, @@ -213,15 +213,19 @@ impl DatabaseSyncEngine { // so, we pass 'capture: false' as we don't need to preserve changes made to Synced WAL in turso_cdc let synced = self.io.open_tape(&self.synced_path, false)?; - // mark Synced as dirty as we will start transfer of logical changes there and if we will fail in the middle - we will need to cleanup Synced db self.synced_is_dirty = true; - let client_id = &self.meta().client_unique_id; - transfer_logical_changes(coro, &self.draft_tape, &synced, client_id, false).await?; - - self.push_synced_to_remote(coro).await?; + push_logical_changes( + coro, + self.protocol.as_ref(), + &self.draft_tape, + &synced, + client_id, + ) + .await?; self.reset_synced_if_dirty(coro).await?; + Ok(()) } @@ -328,6 +332,7 @@ impl DatabaseSyncEngine { } } + #[allow(dead_code)] async fn push_synced_to_remote(&mut self, coro: &Coro) -> Result<()> { tracing::info!( "push_synced_to_remote: draft={}, synced={}, id={}", @@ -447,969 +452,3 @@ impl DatabaseSyncEngine { self.meta.as_ref().expect("metadata must be set") } } - -#[cfg(test)] -pub mod tests { - use std::{collections::BTreeMap, sync::Arc}; - - use rand::RngCore; - use tokio::join; - - use crate::{ - database_sync_engine::DatabaseSyncEngineOpts, - errors::Error, - test_context::{FaultInjectionPlan, FaultInjectionStrategy, TestContext}, - test_protocol_io::TestProtocolIo, - test_sync_server::convert_rows, - tests::{deterministic_runtime, seed_u64, TestRunner}, - Result, - }; - - async fn query_rows(conn: &turso::Connection, sql: 
&str) -> Result>> { - let mut rows = conn.query(sql, ()).await?; - convert_rows(&mut rows).await - } - - #[test] - pub fn test_sync_single_db_simple() { - deterministic_runtime(async || { - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let ctx = Arc::new(TestContext::new(seed_u64())); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - let mut runner = TestRunner::new(ctx.clone(), io, protocol.clone()); - let local_path = dir.path().join("local.db").to_str().unwrap().to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: "id-1".to_string(), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - - protocol - .server - .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) - .await - .unwrap(); - protocol - .server - .execute("INSERT INTO t VALUES (1)", ()) - .await - .unwrap(); - - let conn = runner.connect().await.unwrap(); - - // no table in schema before sync from remote (as DB was initialized when remote was empty) - assert!(matches!( - query_rows(&conn, "SELECT * FROM t").await, - Err(x) if x.to_string().contains("no such table: t") - )); - - // 1 rows synced - runner.pull().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![vec![turso::Value::Integer(1)]] - ); - - protocol - .server - .execute("INSERT INTO t VALUES (2)", ()) - .await - .unwrap(); - - conn.execute("INSERT INTO t VALUES (3)", ()).await.unwrap(); - - // changes are synced from the remote - but remote changes are not propagated locally - runner.push().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(3)] - ] - ); - - let server_db = protocol.server.db(); - let server_conn = server_db.connect().unwrap(); - assert_eq!( - convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) - .await - .unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - ] - ); - - conn.execute("INSERT INTO t VALUES (4)", ()).await.unwrap(); - runner.push().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(3)], - vec![turso::Value::Integer(4)] - ] - ); - - assert_eq!( - convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) - .await - .unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - vec![turso::Value::Integer(4)], - ] - ); - - runner.pull().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - vec![turso::Value::Integer(4)] - ] - ); - - assert_eq!( - convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) - .await - .unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - vec![turso::Value::Integer(4)], - ] - ); - }); - } - - #[test] - pub fn test_sync_single_db_no_changes_no_push() { - deterministic_runtime(async || { - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let ctx = 
Arc::new(TestContext::new(seed_u64())); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - let mut runner = TestRunner::new(ctx.clone(), io, protocol.clone()); - let local_path = dir.path().join("local.db").to_str().unwrap().to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: "id-1".to_string(), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - - protocol - .server - .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) - .await - .unwrap(); - protocol - .server - .execute("INSERT INTO t VALUES (1)", ()) - .await - .unwrap(); - - let conn = runner.connect().await.unwrap(); - - runner.sync().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![vec![turso::Value::Integer(1)]] - ); - - conn.execute("INSERT INTO t VALUES (100)", ()) - .await - .unwrap(); - - protocol - .server - .execute("INSERT INTO t VALUES (2)", ()) - .await - .unwrap(); - - runner.sync().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(100)], - ] - ); - - protocol - .server - .execute("INSERT INTO t VALUES (3)", ()) - .await - .unwrap(); - runner.sync().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - vec![turso::Value::Integer(100)], - ] - ); - - ctx.switch_mode(FaultInjectionStrategy::Enabled { - plan: FaultInjectionPlan { - is_fault: Box::new(|name, _| Box::pin(async move { name == "wal_push_start" })), - }, - }) - .await; - - protocol - .server - .execute("INSERT INTO t VALUES (4)", ()) - .await - .unwrap(); - runner.sync().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - vec![turso::Value::Integer(4)], - vec![turso::Value::Integer(100)], - ] - ); - }); - } - - #[test] - pub fn test_sync_single_db_update_sync_concurrent() { - deterministic_runtime(async || { - let io: Arc = Arc::new(turso_core::MemoryIO::new()); - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let ctx = Arc::new(TestContext::new(seed_u64())); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - let mut runner = TestRunner::new(ctx.clone(), io, protocol.clone()); - let opts = DatabaseSyncEngineOpts { - client_name: "id-1".to_string(), - wal_pull_batch_size: 1, - }; - - protocol - .server - .execute("CREATE TABLE t(x TEXT PRIMARY KEY, y)", ()) - .await - .unwrap(); - protocol - .server - .execute("INSERT INTO t VALUES ('hello', 'world')", ()) - .await - .unwrap(); - - runner.init(":memory:", opts).await.unwrap(); - let conn = runner.connect().await.unwrap(); - - let syncs = async move { - for i in 0..10 { - tracing::info!("sync attempt #{i}"); - runner.sync().await.unwrap(); - } - }; - - let updates = async move { - for i in 0..10 { - tracing::info!("update attempt #{i}"); - let sql = format!("INSERT INTO t VALUES ('key-{i}', 'value-{i}')"); - match conn.execute(&sql, ()).await { - Ok(_) => {} - Err(err) if err.to_string().contains("database is locked") => {} - Err(err) => panic!("update failed: {err}"), - } - ctx.random_sleep_n(50).await; - } - }; - - join!(updates, syncs); - }); - } - - #[test] - pub fn 
test_sync_many_dbs_update_sync_concurrent() { - deterministic_runtime(async || { - let io: Arc = Arc::new(turso_core::MemoryIO::new()); - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let ctx = Arc::new(TestContext::new(seed_u64())); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - protocol - .server - .execute("CREATE TABLE t(x TEXT PRIMARY KEY, y)", ()) - .await - .unwrap(); - protocol - .server - .execute( - "INSERT INTO t VALUES ('id-1', 'client1'), ('id-2', 'client2')", - (), - ) - .await - .unwrap(); - let mut runner1 = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); - runner1 - .init( - ":memory:-1", - DatabaseSyncEngineOpts { - client_name: "id-1".to_string(), - wal_pull_batch_size: 2, - }, - ) - .await - .unwrap(); - let mut runner2 = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); - runner2 - .init( - ":memory:-2", - DatabaseSyncEngineOpts { - client_name: "id-2".to_string(), - wal_pull_batch_size: 2, - }, - ) - .await - .unwrap(); - - let conn1 = runner1.connect().await.unwrap(); - let conn2 = runner2.connect().await.unwrap(); - - let syncs1 = async move { - for i in 0..10 { - tracing::info!("sync attempt #{i}"); - match runner1.sync().await { - Ok(()) | Err(Error::DatabaseSyncEngineConflict(..)) => continue, - Err(err) => panic!("unexpected error: {err}"), - } - } - }; - - let syncs2 = async move { - for i in 0..10 { - tracing::info!("sync attempt #{i}"); - match runner2.sync().await { - Ok(()) | Err(Error::DatabaseSyncEngineConflict(..)) => continue, - Err(err) => panic!("unexpected error: {err}"), - } - } - }; - - let ctx1 = ctx.clone(); - let updates1 = async move { - for i in 0..100 { - tracing::info!("update attempt #{i}"); - let sql = format!("INSERT INTO t VALUES ('key-1-{i}', 'value')"); - match conn1.execute(&sql, ()).await { - Ok(_) => {} - Err(err) if err.to_string().contains("database is locked") => {} - Err(err) => panic!("update failed: {err}"), - } - ctx1.random_sleep_n(10).await; - } - }; - - let ctx2 = ctx.clone(); - let updates2 = async move { - for i in 0..100 { - tracing::info!("update attempt #{i}"); - let sql = format!("INSERT INTO t VALUES ('key-2-{i}', 'value')"); - match conn2.execute(&sql, ()).await { - Ok(_) => {} - Err(err) if err.to_string().contains("database is locked") => {} - Err(err) => panic!("update failed: {err}"), - } - ctx2.random_sleep_n(10).await; - } - }; - - join!(updates1, updates2, syncs1, syncs2); - }); - } - - #[test] - pub fn test_sync_single_db_many_pulls_big_payloads() { - deterministic_runtime(async || { - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let ctx = Arc::new(TestContext::new(seed_u64())); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - let mut runner = TestRunner::new(ctx.clone(), io, protocol.clone()); - let local_path = dir.path().join("local.db").to_str().unwrap().to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: "id-1".to_string(), - wal_pull_batch_size: 1, - }; - - runner.init(&local_path, opts).await.unwrap(); - - protocol - .server - .execute("CREATE TABLE t(x INTEGER PRIMARY KEY, y)", ()) - .await - .unwrap(); - - runner.sync().await.unwrap(); - - // create connection in outer scope in order to prevent Database from being dropped in between of pull operations - let conn = runner.connect().await.unwrap(); - - let mut expected = 
BTreeMap::new(); - for attempt in 0..10 { - for _ in 0..5 { - let key = ctx.rng().await.next_u32(); - let length = ctx.rng().await.next_u32() % (10 * 4096); - protocol - .server - .execute("INSERT INTO t VALUES (?, randomblob(?))", (key, length)) - .await - .unwrap(); - expected.insert(key as i64, length as i64); - } - - tracing::info!("pull attempt={}", attempt); - runner.sync().await.unwrap(); - - let expected = expected - .iter() - .map(|(x, y)| vec![turso::Value::Integer(*x), turso::Value::Integer(*y)]) - .collect::>(); - assert_eq!( - query_rows(&conn, "SELECT x, length(y) FROM t") - .await - .unwrap(), - expected - ); - } - }); - } - - #[test] - pub fn test_sync_single_db_checkpoint() { - deterministic_runtime(async || { - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let ctx = Arc::new(TestContext::new(seed_u64())); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - let mut runner = TestRunner::new(ctx.clone(), io, protocol.clone()); - let local_path = dir.path().join("local.db").to_str().unwrap().to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: "id-1".to_string(), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - - protocol - .server - .execute("CREATE TABLE t(x INTEGER PRIMARY KEY, y)", ()) - .await - .unwrap(); - protocol - .server - .execute("INSERT INTO t VALUES (1, randomblob(4 * 4096))", ()) - .await - .unwrap(); - protocol.server.checkpoint().await.unwrap(); - protocol - .server - .execute("INSERT INTO t VALUES (2, randomblob(5 * 4096))", ()) - .await - .unwrap(); - protocol.server.checkpoint().await.unwrap(); - protocol - .server - .execute("INSERT INTO t VALUES (3, randomblob(6 * 4096))", ()) - .await - .unwrap(); - protocol - .server - .execute("INSERT INTO t VALUES (4, randomblob(7 * 4096))", ()) - .await - .unwrap(); - - let conn = runner.connect().await.unwrap(); - - runner.pull().await.unwrap(); - - assert_eq!( - query_rows(&conn, "SELECT x, length(y) FROM t") - .await - .unwrap(), - vec![ - vec![turso::Value::Integer(1), turso::Value::Integer(4 * 4096)], - vec![turso::Value::Integer(2), turso::Value::Integer(5 * 4096)], - vec![turso::Value::Integer(3), turso::Value::Integer(6 * 4096)], - vec![turso::Value::Integer(4), turso::Value::Integer(7 * 4096)], - ] - ); - }); - } - - #[test] - pub fn test_sync_single_db_full_syncs() { - deterministic_runtime(async || { - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let ctx = Arc::new(TestContext::new(seed_u64())); - let server = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - let mut runner = TestRunner::new(ctx.clone(), io.clone(), server.clone()); - let local_path = dir.path().join("local.db").to_str().unwrap().to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: "id-1".to_string(), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - - server - .server - .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) - .await - .unwrap(); - server - .server - .execute("INSERT INTO t VALUES (1)", ()) - .await - .unwrap(); - - let conn = runner.connect().await.unwrap(); - - // no table in schema before sync from remote (as DB was initialized when remote was empty) - assert!(matches!( - query_rows(&conn, "SELECT * FROM t").await, - Err(x) if 
x.to_string().contains("no such table: t") - )); - - runner.sync().await.unwrap(); - - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![vec![turso::Value::Integer(1)]] - ); - - conn.execute("INSERT INTO t VALUES (2)", ()).await.unwrap(); - runner.sync().await.unwrap(); - - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)] - ] - ); - - conn.execute("INSERT INTO t VALUES (3)", ()).await.unwrap(); - runner.sync().await.unwrap(); - assert_eq!( - query_rows(&conn, "SELECT * FROM t").await.unwrap(), - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)] - ] - ); - }); - } - - #[test] - pub fn test_sync_multiple_dbs_conflict() { - deterministic_runtime(async || { - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let ctx = Arc::new(TestContext::new(seed_u64())); - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - let mut dbs = Vec::new(); - const CLIENTS: usize = 8; - for i in 0..CLIENTS { - let mut runner = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); - let local_path = dir - .path() - .join(format!("local-{i}.db")) - .to_str() - .unwrap() - .to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: format!("id-{i}"), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - dbs.push(runner); - } - - protocol - .server - .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) - .await - .unwrap(); - - for db in &mut dbs { - db.pull().await.unwrap(); - } - for (i, db) in dbs.iter().enumerate() { - let conn = db.connect().await.unwrap(); - conn.execute("INSERT INTO t VALUES (?)", (i as i32,)) - .await - .unwrap(); - } - - let try_sync = || async { - let mut tasks = Vec::new(); - for db in &dbs { - tasks.push(async move { db.push().await }); - } - futures::future::join_all(tasks).await - }; - for attempt in 0..CLIENTS { - let results = try_sync().await; - tracing::info!("attempt #{}: {:?}", attempt, results); - assert!(results.iter().filter(|x| x.is_ok()).count() > attempt); - } - }); - } - - #[test] - pub fn test_sync_multiple_clients_no_conflicts_synchronized() { - deterministic_runtime(async || { - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let ctx = Arc::new(TestContext::new(seed_u64())); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - - protocol - .server - .execute("CREATE TABLE t(k INTEGER PRIMARY KEY, v)", ()) - .await - .unwrap(); - - let sync_lock = Arc::new(tokio::sync::Mutex::new(())); - let mut clients = Vec::new(); - const CLIENTS: usize = 10; - let mut expected_rows = Vec::new(); - for i in 0..CLIENTS { - let mut queries = Vec::new(); - let cnt = ctx.rng().await.next_u32() % CLIENTS as u32 + 1; - for q in 0..cnt { - let key = i * CLIENTS + q as usize; - let length = ctx.rng().await.next_u32() % 4096; - queries.push(format!( - "INSERT INTO t VALUES ({key}, randomblob({length}))", - )); - expected_rows.push(vec![ - turso::Value::Integer(key as i64), - turso::Value::Integer(length as i64), - ]); - } - clients.push({ - let io = io.clone(); - let dir = dir.path().to_path_buf().clone(); - let ctx = ctx.clone(); - let server = protocol.clone(); - let sync_lock = 
sync_lock.clone(); - async move { - let mut runner = TestRunner::new(ctx.clone(), io.clone(), server.clone()); - let local_path = dir - .join(format!("local-{i}.db")) - .to_str() - .unwrap() - .to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: format!("id-{i}"), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - runner.pull().await.unwrap(); - let conn = runner.connect().await.unwrap(); - for query in queries { - conn.execute(&query, ()).await.unwrap(); - } - let guard = sync_lock.lock().await; - runner.push().await.unwrap(); - drop(guard); - } - }); - } - for client in clients { - client.await; - } - let db = protocol.server.db(); - let conn = db.connect().unwrap(); - let mut result = conn.query("SELECT k, length(v) FROM t", ()).await.unwrap(); - let rows = convert_rows(&mut result).await.unwrap(); - assert_eq!(rows, expected_rows); - }); - } - - #[test] - pub fn test_sync_single_db_sync_from_remote_nothing_single_failure() { - deterministic_runtime(async || { - let dir = tempfile::TempDir::new().unwrap(); - let server_path = dir.path().join("server.db"); - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let ctx = Arc::new(TestContext::new(seed_u64())); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - - protocol - .server - .execute("CREATE TABLE t(x)", ()) - .await - .unwrap(); - protocol - .server - .execute("INSERT INTO t VALUES (1), (2), (3)", ()) - .await - .unwrap(); - - let mut session = ctx.fault_session(); - let mut it = 0; - while let Some(strategy) = session.next().await { - it += 1; - - let mut runner = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); - let local_path = dir - .path() - .join(format!("local-{it}.db")) - .to_str() - .unwrap() - .to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: format!("id-{it}"), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - - let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. 
}); - - ctx.switch_mode(strategy).await; - let result = runner.pull().await; - ctx.switch_mode(FaultInjectionStrategy::Disabled).await; - - if !has_fault { - result.unwrap(); - } else { - let err = result.err().unwrap(); - tracing::info!("error after fault injection: {}", err); - } - - let conn = runner.connect().await.unwrap(); - let rows = query_rows(&conn, "SELECT * FROM t").await.unwrap(); - assert_eq!( - rows, - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - ] - ); - - runner.pull().await.unwrap(); - - let rows = query_rows(&conn, "SELECT * FROM t").await.unwrap(); - assert_eq!( - rows, - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - ] - ); - } - }); - } - - #[test] - pub fn test_sync_single_db_sync_from_remote_single_failure() { - deterministic_runtime(async || { - let dir = tempfile::TempDir::new().unwrap(); - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let ctx = Arc::new(TestContext::new(seed_u64())); - - let mut session = ctx.fault_session(); - let mut it = 0; - while let Some(strategy) = session.next().await { - it += 1; - - let server_path = dir.path().join(format!("server-{it}.db")); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - - protocol - .server - .execute("CREATE TABLE t(x)", ()) - .await - .unwrap(); - - let mut runner = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); - let local_path = dir - .path() - .join(format!("local-{it}.db")) - .to_str() - .unwrap() - .to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: format!("id-{it}"), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - - protocol - .server - .execute("INSERT INTO t VALUES (1), (2), (3)", ()) - .await - .unwrap(); - - let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. 
}); - - ctx.switch_mode(strategy).await; - let result = runner.pull().await; - ctx.switch_mode(FaultInjectionStrategy::Disabled).await; - - if !has_fault { - result.unwrap(); - } else { - let err = result.err().unwrap(); - tracing::info!("error after fault injection: {}", err); - } - - let conn = runner.connect().await.unwrap(); - let rows = query_rows(&conn, "SELECT * FROM t").await.unwrap(); - assert!(rows.len() <= 3); - - runner.pull().await.unwrap(); - - let rows = query_rows(&conn, "SELECT * FROM t").await.unwrap(); - assert_eq!( - rows, - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - ] - ); - } - }); - } - - #[test] - pub fn test_sync_single_db_sync_to_remote_single_failure() { - deterministic_runtime(async || { - let dir = tempfile::TempDir::new().unwrap(); - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let ctx = Arc::new(TestContext::new(seed_u64())); - - let mut session = ctx.fault_session(); - let mut it = 0; - while let Some(strategy) = session.next().await { - it += 1; - - let server_path = dir.path().join(format!("server-{it}.db")); - let protocol = TestProtocolIo::new(ctx.clone(), &server_path) - .await - .unwrap(); - - protocol - .server - .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) - .await - .unwrap(); - - protocol - .server - .execute("INSERT INTO t VALUES (1)", ()) - .await - .unwrap(); - - let mut runner = TestRunner::new(ctx.clone(), io.clone(), protocol.clone()); - let local_path = dir - .path() - .join(format!("local-{it}.db")) - .to_str() - .unwrap() - .to_string(); - let opts = DatabaseSyncEngineOpts { - client_name: format!("id-{it}"), - wal_pull_batch_size: 1, - }; - runner.init(&local_path, opts).await.unwrap(); - - let conn = runner.connect().await.unwrap(); - - conn.execute("INSERT INTO t VALUES (2), (3)", ()) - .await - .unwrap(); - - let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. 
}); - - ctx.switch_mode(strategy).await; - let result = runner.push().await; - ctx.switch_mode(FaultInjectionStrategy::Disabled).await; - - if !has_fault { - result.unwrap(); - } else { - let err = result.err().unwrap(); - tracing::info!("error after fault injection: {}", err); - } - - let server_db = protocol.server.db(); - let server_conn = server_db.connect().unwrap(); - let rows = - convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) - .await - .unwrap(); - assert!(rows.len() <= 3); - - runner.push().await.unwrap(); - - let rows = - convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) - .await - .unwrap(); - assert_eq!( - rows, - vec![ - vec![turso::Value::Integer(1)], - vec![turso::Value::Integer(2)], - vec![turso::Value::Integer(3)], - ] - ); - } - }); - } -} diff --git a/packages/turso-sync-engine/src/database_sync_operations.rs b/packages/turso-sync-engine/src/database_sync_operations.rs index 683a7bbee..3851db953 100644 --- a/packages/turso-sync-engine/src/database_sync_operations.rs +++ b/packages/turso-sync-engine/src/database_sync_operations.rs @@ -3,13 +3,18 @@ use std::sync::Arc; use turso_core::{types::Text, Buffer, Completion, LimboError, Value}; use crate::{ + database_replay_generator::DatabaseReplayGenerator, database_tape::{ exec_stmt, run_stmt_expect_one_row, DatabaseChangesIteratorMode, DatabaseChangesIteratorOpts, DatabaseReplaySessionOpts, DatabaseTape, DatabaseWalSession, }, errors::Error, protocol_io::{DataCompletion, DataPollResult, ProtocolIO}, - types::{Coro, DatabaseTapeOperation, DbSyncInfo, DbSyncStatus, ProtocolCommand}, + server_proto::{self, ExecuteStreamReq, Stmt, StreamRequest}, + types::{ + Coro, DatabaseTapeOperation, DatabaseTapeRowChangeType, DbSyncInfo, DbSyncStatus, + ProtocolCommand, + }, wal_session::WalSession, Result, }; @@ -263,6 +268,8 @@ const TURSO_SYNC_INSERT_LAST_CHANGE_ID: &str = "INSERT INTO turso_sync_last_change_id(client_id, pull_gen, change_id) VALUES (?, 0, 0)"; const TURSO_SYNC_UPDATE_LAST_CHANGE_ID: &str = "UPDATE turso_sync_last_change_id SET pull_gen = ?, change_id = ? WHERE client_id = ?"; +const TURSO_SYNC_UPSERT_LAST_CHANGE_ID: &str = + "INSERT INTO turso_sync_last_change_id(client_id, pull_gen, change_id) VALUES (?, ?, ?) 
ON CONFLICT(client_id) DO UPDATE SET pull_gen=excluded.pull_gen, change_id=excluded.change_id";
 
 /// Transfers row changes from source DB to target DB
 /// In order to guarantee atomicity and avoid conflicts - method maintain last_change_id counter in the target db table turso_sync_last_change_id
@@ -419,6 +426,307 @@ pub async fn transfer_logical_changes(
     Ok(())
 }
 
+fn convert_to_args(values: Vec<Value>) -> Vec<server_proto::Value> {
+    values
+        .into_iter()
+        .map(|value| match value {
+            Value::Null => server_proto::Value::Null,
+            Value::Integer(value) => server_proto::Value::Integer { value },
+            Value::Float(value) => server_proto::Value::Float { value },
+            Value::Text(value) => server_proto::Value::Text {
+                value: value.as_str().to_string(),
+            },
+            Value::Blob(value) => server_proto::Value::Blob {
+                value: value.into(),
+            },
+        })
+        .collect()
+}
+
+pub async fn push_logical_changes<C: ProtocolIO>(
+    coro: &Coro,
+    client: &C,
+    source: &DatabaseTape,
+    target: &DatabaseTape,
+    client_id: &str,
+) -> Result<()> {
+    tracing::info!("push_logical_changes: client_id={client_id}");
+    let source_conn = connect_untracked(source)?;
+    let target_conn = connect_untracked(target)?;
+
+    // fetch pull_gen from the source DB so it can be compared against the target's pull generation below
+    let source_pull_gen = 'source_pull_gen: {
+        let mut select_last_change_id_stmt =
+            match source_conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID) {
+                Ok(stmt) => stmt,
+                Err(LimboError::ParseError(..)) => break 'source_pull_gen 0,
+                Err(err) => return Err(err.into()),
+            };
+
+        select_last_change_id_stmt
+            .bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id)));
+
+        match run_stmt_expect_one_row(coro, &mut select_last_change_id_stmt).await? {
+            Some(row) => row[0].as_int().ok_or_else(|| {
+                Error::DatabaseSyncEngineError("unexpected source pull_gen type".to_string())
+            })?,
+            None => {
+                tracing::info!("push_logical_changes: client_id={client_id}, turso_sync_last_change_id table is not found");
+                0
+            }
+        }
+    };
+    tracing::info!(
+        "push_logical_changes: client_id={client_id}, source_pull_gen={source_pull_gen}"
+    );
+
+    // fetch last_change_id from the target DB in order to guarantee atomic replay of changes and avoid conflicts in case of failure
+    let mut schema_stmt = target_conn.prepare(TURSO_SYNC_CREATE_TABLE)?;
+    exec_stmt(coro, &mut schema_stmt).await?;
+
+    let mut select_last_change_id_stmt = target_conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID)?;
+    select_last_change_id_stmt.bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id)));
+
+    let mut last_change_id = match run_stmt_expect_one_row(coro, &mut select_last_change_id_stmt)
+        .await?
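+    // the target's change counter is only reused when both sides are on the same
+    // pull generation; a newer source generation restarts replay from change id 0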
+ { + Some(row) => { + let target_pull_gen = row[0].as_int().ok_or_else(|| { + Error::DatabaseSyncEngineError("unexpected target pull_gen type".to_string()) + })?; + let target_change_id = row[1].as_int().ok_or_else(|| { + Error::DatabaseSyncEngineError("unexpected target change_id type".to_string()) + })?; + tracing::debug!( + "push_logical_changes: client_id={client_id}, target_pull_gen={target_pull_gen}, target_change_id={target_change_id}" + ); + if target_pull_gen > source_pull_gen { + return Err(Error::DatabaseSyncEngineError(format!("protocol error: target_pull_gen > source_pull_gen: {target_pull_gen} > {source_pull_gen}"))); + } + if target_pull_gen == source_pull_gen { + Some(target_change_id) + } else { + Some(0) + } + } + None => { + let mut insert_last_change_id_stmt = + target_conn.prepare(TURSO_SYNC_INSERT_LAST_CHANGE_ID)?; + insert_last_change_id_stmt + .bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id))); + exec_stmt(coro, &mut insert_last_change_id_stmt).await?; + None + } + }; + + tracing::debug!("push_logical_changes: last_change_id={:?}", last_change_id); + let replay_opts = DatabaseReplaySessionOpts { + use_implicit_rowid: false, + }; + + let conn = connect_untracked(target)?; + let generator = DatabaseReplayGenerator::new(conn, replay_opts); + + let iterate_opts = DatabaseChangesIteratorOpts { + first_change_id: last_change_id.map(|x| x + 1), + mode: DatabaseChangesIteratorMode::Apply, + ignore_schema_changes: false, + ..Default::default() + }; + let mut sql_over_http_requests = vec![ + Stmt { + sql: Some("BEGIN IMMEDIATE".to_string()), + sql_id: None, + args: Vec::new(), + named_args: Vec::new(), + want_rows: Some(false), + replication_index: None, + }, + Stmt { + sql: Some(TURSO_SYNC_CREATE_TABLE.to_string()), + sql_id: None, + args: Vec::new(), + named_args: Vec::new(), + want_rows: Some(false), + replication_index: None, + }, + ]; + let mut rows_changed = 0; + let mut changes = source.iterate_changes(iterate_opts)?; + while let Some(operation) = changes.next(coro).await? 
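+    // each CDC row change below becomes one Hrana statement (or a delete+insert
+    // pair for a full-row update) appended to the pipeline request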
{
+        match operation {
+            DatabaseTapeOperation::RowChange(change) => {
+                assert!(
+                    last_change_id.is_none() || last_change_id.unwrap() < change.change_id,
+                    "change id must be strictly increasing: last_change_id={:?}, change.change_id={}",
+                    last_change_id,
+                    change.change_id
+                );
+                if change.table_name == TURSO_SYNC_TABLE_NAME {
+                    continue;
+                }
+                rows_changed += 1;
+                // the user has full control over the CDC table - so let's not assert here for now
+                if last_change_id.is_some() && last_change_id.unwrap() + 1 != change.change_id {
+                    tracing::warn!(
+                        "out of order change sequence: {} -> {}",
+                        last_change_id.unwrap(),
+                        change.change_id
+                    );
+                }
+                last_change_id = Some(change.change_id);
+                let replay_info = generator.replay_info(coro, &change).await?;
+                let change_type = (&change.change).into();
+                match change.change {
+                    DatabaseTapeRowChangeType::Delete { before } => {
+                        assert!(replay_info.len() == 1);
+                        let values = generator.replay_values(
+                            &replay_info[0],
+                            change_type,
+                            change.id,
+                            before,
+                            None,
+                        );
+                        sql_over_http_requests.push(Stmt {
+                            sql: Some(replay_info[0].query.clone()),
+                            sql_id: None,
+                            args: convert_to_args(values),
+                            named_args: Vec::new(),
+                            want_rows: Some(false),
+                            replication_index: None,
+                        })
+                    }
+                    DatabaseTapeRowChangeType::Insert { after } => {
+                        assert!(replay_info.len() == 1);
+                        let values = generator.replay_values(
+                            &replay_info[0],
+                            change_type,
+                            change.id,
+                            after,
+                            None,
+                        );
+                        sql_over_http_requests.push(Stmt {
+                            sql: Some(replay_info[0].query.clone()),
+                            sql_id: None,
+                            args: convert_to_args(values),
+                            named_args: Vec::new(),
+                            want_rows: Some(false),
+                            replication_index: None,
+                        })
+                    }
+                    DatabaseTapeRowChangeType::Update {
+                        after,
+                        updates: Some(updates),
+                        ..
+                    } => {
+                        assert!(replay_info.len() == 1);
+                        let values = generator.replay_values(
+                            &replay_info[0],
+                            change_type,
+                            change.id,
+                            after,
+                            Some(updates),
+                        );
+                        sql_over_http_requests.push(Stmt {
+                            sql: Some(replay_info[0].query.clone()),
+                            sql_id: None,
+                            args: convert_to_args(values),
+                            named_args: Vec::new(),
+                            want_rows: Some(false),
+                            replication_index: None,
+                        })
+                    }
+                    DatabaseTapeRowChangeType::Update {
+                        before,
+                        after,
+                        updates: None,
+                    } => {
+                        assert!(replay_info.len() == 2);
+                        // replay_info holds a delete+insert pair here, so bind each with its
+                        // own change type; the generic `change_type` (Update) would hit the
+                        // `updates.unwrap()` path in replay_values and panic on `None`
+                        let values = generator.replay_values(
+                            &replay_info[0],
+                            crate::types::DatabaseChangeType::Delete,
+                            change.id,
+                            before,
+                            None,
+                        );
+                        sql_over_http_requests.push(Stmt {
+                            sql: Some(replay_info[0].query.clone()),
+                            sql_id: None,
+                            args: convert_to_args(values),
+                            named_args: Vec::new(),
+                            want_rows: Some(false),
+                            replication_index: None,
+                        });
+                        let values = generator.replay_values(
+                            &replay_info[1],
+                            crate::types::DatabaseChangeType::Insert,
+                            change.id,
+                            after,
+                            None,
+                        );
+                        sql_over_http_requests.push(Stmt {
+                            sql: Some(replay_info[1].query.clone()),
+                            sql_id: None,
+                            args: convert_to_args(values),
+                            named_args: Vec::new(),
+                            want_rows: Some(false),
+                            replication_index: None,
+                        });
+                    }
+                }
+            }
+            DatabaseTapeOperation::Commit => {
+                if rows_changed > 0 {
+                    tracing::info!("prepare update stmt for turso_sync_last_change_id table with client_id={} and last_change_id={:?}", client_id, last_change_id);
+                    // update turso_sync_last_change_id table with new value before commit
+                    let (next_pull_gen, next_change_id) =
+                        (source_pull_gen, last_change_id.unwrap_or(0));
+                    tracing::info!("push_logical_changes: client_id={client_id}, set pull_gen={next_pull_gen}, change_id={next_change_id}, rows_changed={rows_changed}");
+                    sql_over_http_requests.push(Stmt {
+                        sql: Some(TURSO_SYNC_UPSERT_LAST_CHANGE_ID.to_string()),
+                        sql_id: None,
+                        args: vec![
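+                            // bind order must match the upsert placeholders:
+                            // (client_id, pull_gen, change_id)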
server_proto::Value::Text { + value: client_id.to_string(), + }, + server_proto::Value::Integer { + value: next_pull_gen, + }, + server_proto::Value::Integer { + value: next_change_id, + }, + ], + named_args: Vec::new(), + want_rows: Some(false), + replication_index: None, + }); + } + sql_over_http_requests.push(Stmt { + sql: Some("COMMIT".to_string()), + sql_id: None, + args: Vec::new(), + named_args: Vec::new(), + want_rows: Some(false), + replication_index: None, + }); + } + } + } + + tracing::debug!("hrana request: {:?}", sql_over_http_requests); + let request = server_proto::PipelineReqBody { + baton: None, + requests: sql_over_http_requests + .into_iter() + .map(|stmt| StreamRequest::Execute(ExecuteStreamReq { stmt })) + .collect(), + }; + + sql_execute_http(coro, client, request).await?; + tracing::info!("push_logical_changes: rows_changed={:?}", rows_changed); + Ok(()) +} + /// Replace WAL frames [target_wal_match_watermark..) in the target DB with frames [source_wal_match_watermark..) from source DB /// Return the position in target DB wal which logically equivalent to the source_sync_watermark in the source DB WAL pub async fn transfer_physical_changes( @@ -451,6 +759,8 @@ pub async fn transfer_physical_changes( let target_sync_watermark = { let mut target_session = DatabaseWalSession::new(coro, target_session).await?; + tracing::info!("rollback_changes_after: {target_wal_match_watermark}"); + target_session.rollback_changes_after(target_wal_match_watermark)?; let mut last_frame_info = None; let mut frame = vec![0u8; WAL_FRAME_SIZE]; @@ -462,7 +772,7 @@ pub async fn transfer_physical_changes( ); for source_frame_no in source_wal_match_watermark + 1..=source_frames_count { let frame_info = source_conn.wal_get_frame(source_frame_no, &mut frame)?; - tracing::trace!("append page {} to target DB", frame_info.page_no); + tracing::debug!("append page {} to target DB", frame_info.page_no); target_session.append_page(frame_info.page_no, &frame[WAL_FRAME_HEADER..])?; if source_frame_no == source_sync_watermark { target_sync_watermark = target_session.frames_count()? 
+ 1; // +1 because page will be actually commited on next iteration @@ -471,7 +781,7 @@ pub async fn transfer_physical_changes( last_frame_info = Some(frame_info); } let db_size = last_frame_info.unwrap().db_size; - tracing::trace!("commit WAL session to target with db_size={db_size}"); + tracing::debug!("commit WAL session to target with db_size={db_size}"); target_session.commit(db_size)?; assert!(target_sync_watermark != 0); target_sync_watermark @@ -518,6 +828,30 @@ pub async fn reset_wal_file( Ok(()) } +async fn sql_execute_http( + coro: &Coro, + client: &C, + request: server_proto::PipelineReqBody, +) -> Result<()> { + let body = serde_json::to_vec(&request)?; + let completion = client.http("POST", "/v2/pipeline", Some(body))?; + let status = wait_status(coro, &completion).await?; + if status != http::StatusCode::OK { + let error = format!("sql_execute_http: unexpected status code: {status}"); + return Err(Error::DatabaseSyncEngineError(error)); + } + let response = wait_full_body(coro, &completion).await?; + let response: server_proto::PipelineRespBody = serde_json::from_slice(&response)?; + for result in response.results { + if let server_proto::StreamResult::Error { error } = result { + return Err(Error::DatabaseSyncEngineError(format!( + "failed to execute sql: {error:?}" + ))); + } + } + Ok(()) +} + async fn wal_pull_http( coro: &Coro, client: &C, diff --git a/packages/turso-sync-engine/src/database_tape.rs b/packages/turso-sync-engine/src/database_tape.rs index bec747df8..310327a03 100644 --- a/packages/turso-sync-engine/src/database_tape.rs +++ b/packages/turso-sync-engine/src/database_tape.rs @@ -6,6 +6,7 @@ use std::{ use turso_core::{types::WalFrameInfo, StepResult}; use crate::{ + database_replay_generator::{DatabaseReplayGenerator, ReplayInfo}, database_sync_operations::WAL_FRAME_HEADER, errors::Error, types::{ @@ -170,12 +171,12 @@ impl DatabaseTape { let conn = self.connect(coro).await?; conn.execute("BEGIN IMMEDIATE")?; Ok(DatabaseReplaySession { - conn, + conn: conn.clone(), cached_delete_stmt: HashMap::new(), cached_insert_stmt: HashMap::new(), cached_update_stmt: HashMap::new(), in_txn: true, - opts, + generator: DatabaseReplayGenerator { conn, opts }, }) } } @@ -414,23 +415,31 @@ pub struct DatabaseReplaySessionOpts { pub use_implicit_rowid: bool, } -struct DeleteCachedStmt { +struct CachedStmt { stmt: turso_core::Statement, - pk_column_indices: Option>, // if None - use rowid instead -} - -struct UpdateCachedStmt { - stmt: turso_core::Statement, - pk_column_indices: Option>, // if None - use rowid instead + info: ReplayInfo, } pub struct DatabaseReplaySession { conn: Arc, - cached_delete_stmt: HashMap, - cached_insert_stmt: HashMap<(String, usize), turso_core::Statement>, - cached_update_stmt: HashMap<(String, Vec), UpdateCachedStmt>, + cached_delete_stmt: HashMap, + cached_insert_stmt: HashMap<(String, usize), CachedStmt>, + cached_update_stmt: HashMap<(String, Vec), CachedStmt>, in_txn: bool, - opts: DatabaseReplaySessionOpts, + generator: DatabaseReplayGenerator, +} + +async fn replay_stmt( + coro: &Coro, + cached: &mut CachedStmt, + values: Vec, +) -> Result<()> { + cached.stmt.reset(); + for (i, value) in values.into_iter().enumerate() { + cached.stmt.bind_at((i + 1).try_into().unwrap(), value); + } + exec_stmt(coro, &mut cached.stmt).await?; + Ok(()) } impl DatabaseReplaySession { @@ -452,107 +461,116 @@ impl DatabaseReplaySession { self.conn.execute("BEGIN IMMEDIATE")?; self.in_txn = true; } - tracing::trace!("replay: change={:?}", change); - let table_name 
= &change.table_name; + let table = &change.table_name; + let change_type = (&change.change).into(); - if table_name == SQLITE_SCHEMA_TABLE { - // sqlite_schema table: type, name, tbl_name, rootpage, sql - match change.change { - DatabaseTapeRowChangeType::Delete { before } => { - let before = parse_bin_record(before)?; - assert!(before.len() == 5); - let Some(turso_core::Value::Text(entity_type)) = before.first() else { - panic!( - "unexpected 'type' column of sqlite_schema table: {:?}", - before.first() - ); - }; - let Some(turso_core::Value::Text(entity_name)) = before.get(1) else { - panic!( - "unexpected 'name' column of sqlite_schema table: {:?}", - before.get(1) - ); - }; - self.conn.execute(format!( - "DROP {} {}", - entity_type.as_str(), - entity_name.as_str() - ))?; - } - DatabaseTapeRowChangeType::Insert { after } => { - let after = parse_bin_record(after)?; - assert!(after.len() == 5); - let Some(turso_core::Value::Text(sql)) = after.last() else { - return Err(Error::DatabaseTapeError(format!( - "unexpected 'sql' column of sqlite_schema table: {:?}", - after.last() - ))); - }; - self.conn.execute(sql.as_str())?; - } - DatabaseTapeRowChangeType::Update { updates, .. } => { - let Some(updates) = updates else { - return Err(Error::DatabaseTapeError( - "'updates' column of CDC table must be populated".to_string(), - )); - }; - let updates = parse_bin_record(updates)?; - assert!(updates.len() % 2 == 0); - assert!(updates.len() / 2 == 5); - let turso_core::Value::Text(ddl_stmt) = updates.last().unwrap() else { - panic!( - "unexpected 'sql' column of sqlite_schema table update record: {:?}", - updates.last() - ); - }; - self.conn.execute(ddl_stmt.as_str())?; - } + if table == SQLITE_SCHEMA_TABLE { + let replay_info = self.generator.replay_info(coro, &change).await?; + for replay in &replay_info { + self.conn.execute(replay.query.as_str())?; } } else { match change.change { DatabaseTapeRowChangeType::Delete { before } => { - let before = parse_bin_record(before)?; - self.replay_delete(coro, table_name, change.id, before) - .await? + let key = self.populate_delete_stmt(coro, table).await?; + tracing::trace!( + "ready to use prepared delete statement for replay: key={}", + key + ); + let cached = self.cached_delete_stmt.get_mut(key).unwrap(); + cached.stmt.reset(); + let values = self.generator.replay_values( + &cached.info, + change_type, + change.id, + before, + None, + ); + replay_stmt(coro, cached, values).await?; + } + DatabaseTapeRowChangeType::Insert { after } => { + let key = self.populate_insert_stmt(coro, table, after.len()).await?; + tracing::trace!( + "ready to use prepared insert statement for replay: key={:?}", + key + ); + let cached = self.cached_insert_stmt.get_mut(&key).unwrap(); + cached.stmt.reset(); + let values = self.generator.replay_values( + &cached.info, + change_type, + change.id, + after, + None, + ); + replay_stmt(coro, cached, values).await?; + } + DatabaseTapeRowChangeType::Update { + after, + updates: Some(updates), + .. 
+ } => { + assert!(updates.len() % 2 == 0); + let columns_cnt = updates.len() / 2; + let mut columns = Vec::with_capacity(columns_cnt); + for value in updates.iter().take(columns_cnt) { + columns.push(match value { + turso_core::Value::Integer(x @ (1 | 0)) => *x > 0, + _ => panic!("unexpected 'changes' binary record first-half component: {value:?}") + }); + } + let key = self.populate_update_stmt(coro, table, &columns).await?; + tracing::trace!( + "ready to use prepared update statement for replay: key={:?}", + key + ); + let cached = self.cached_update_stmt.get_mut(&key).unwrap(); + cached.stmt.reset(); + let values = self.generator.replay_values( + &cached.info, + change_type, + change.id, + after, + Some(updates), + ); + replay_stmt(coro, cached, values).await?; } DatabaseTapeRowChangeType::Update { before, after, - updates, + updates: None, } => { - let after = parse_bin_record(after)?; - if let Some(updates) = updates { - let updates = parse_bin_record(updates)?; - assert!(updates.len() % 2 == 0); - let columns_cnt = updates.len() / 2; - let mut columns = Vec::with_capacity(columns_cnt); - let mut values = Vec::with_capacity(columns_cnt); - for (i, value) in updates.into_iter().enumerate() { - if i < columns_cnt { - columns.push(match value { - turso_core::Value::Integer(x @ (1 | 0)) => x > 0, - _ => panic!("unexpected 'changes' binary record first-half component: {value:?}") - }) - } else { - values.push(value); - } - } - self.replay_update( - coro, table_name, change.id, columns, after, values, - ) - .await?; - } else { - let before = parse_bin_record(before)?; - self.replay_delete(coro, table_name, change.id, before) - .await?; - self.replay_insert(coro, table_name, change.id, after) - .await?; - } - } - DatabaseTapeRowChangeType::Insert { after } => { - let values = parse_bin_record(after)?; - self.replay_insert(coro, table_name, change.id, values) - .await?; + let key = self.populate_delete_stmt(coro, table).await?; + tracing::trace!( + "ready to use prepared delete statement for replay of update: key={:?}", + key + ); + let cached = self.cached_delete_stmt.get_mut(key).unwrap(); + cached.stmt.reset(); + let values = self.generator.replay_values( + &cached.info, + change_type, + change.id, + before, + None, + ); + replay_stmt(coro, cached, values).await?; + + let key = self.populate_insert_stmt(coro, table, after.len()).await?; + tracing::trace!( + "ready to use prepared insert statement for replay of update: key={:?}", + key + ); + let cached = self.cached_insert_stmt.get_mut(&key).unwrap(); + cached.stmt.reset(); + let values = self.generator.replay_values( + &cached.info, + change_type, + change.id, + after, + None, + ); + replay_stmt(coro, cached, values).await?; } } } @@ -560,289 +578,55 @@ impl DatabaseReplaySession { } Ok(()) } - async fn replay_delete( - &mut self, - coro: &Coro, - table_name: &str, - id: i64, - mut values: Vec, - ) -> Result<()> { - let cached = self.cached_delete_stmt(coro, table_name).await?; - if let Some(pk_column_indices) = &cached.pk_column_indices { - for (i, pk_column) in pk_column_indices.iter().enumerate() { - let value = std::mem::replace(&mut values[*pk_column], turso_core::Value::Null); - cached.stmt.bind_at((i + 1).try_into().unwrap(), value); - } - } else { - let value = turso_core::Value::Integer(id); - cached.stmt.bind_at(1.try_into().unwrap(), value); + async fn populate_delete_stmt<'a>(&mut self, coro: &Coro, table: &'a str) -> Result<&'a str> { + if self.cached_delete_stmt.contains_key(table) { + return Ok(table); } - exec_stmt(coro, 
-        Ok(())
+        tracing::trace!("prepare delete statement for replay: table={}", table);
+        let info = self.generator.delete_query(coro, table).await?;
+        let stmt = self.conn.prepare(&info.query)?;
+        self.cached_delete_stmt
+            .insert(table.to_string(), CachedStmt { stmt, info });
+        Ok(table)
     }
-    async fn replay_insert(
+    async fn populate_insert_stmt(
         &mut self,
         coro: &Coro,
-        table_name: &str,
-        id: i64,
-        values: Vec<turso_core::Value>,
-    ) -> Result<()> {
-        let columns = values.len();
-        let use_implicit_rowid = self.opts.use_implicit_rowid;
-        let stmt = self.cached_insert_stmt(coro, table_name, columns).await?;
-        stmt.reset();
-
-        for (i, value) in values.into_iter().enumerate() {
-            stmt.bind_at((i + 1).try_into().unwrap(), value);
-        }
-        if use_implicit_rowid {
-            stmt.bind_at(
-                (columns + 1).try_into().unwrap(),
-                turso_core::Value::Integer(id),
-            );
-        }
-        exec_stmt(coro, stmt).await?;
-        Ok(())
-    }
-    async fn replay_update(
-        &mut self,
-        coro: &Coro,
-        table_name: &str,
-        id: i64,
-        columns: Vec<bool>,
-        mut full: Vec<turso_core::Value>,
-        mut updates: Vec<turso_core::Value>,
-    ) -> Result<()> {
-        let cached = self.cached_update_stmt(coro, table_name, &columns).await?;
-        let mut position: usize = 1;
-        for (i, updated) in columns.into_iter().enumerate() {
-            if !updated {
-                continue;
-            }
-            let value = std::mem::replace(&mut updates[i], turso_core::Value::Null);
-            cached.stmt.bind_at(position.try_into().unwrap(), value);
-            position += 1;
-        }
-        if let Some(pk_column_indices) = &cached.pk_column_indices {
-            for pk_column in pk_column_indices {
-                let value = std::mem::replace(&mut full[*pk_column], turso_core::Value::Null);
-                cached.stmt.bind_at(position.try_into().unwrap(), value);
-                position += 1
-            }
-        } else {
-            let value = turso_core::Value::Integer(id);
-            cached.stmt.bind_at(position.try_into().unwrap(), value);
-        }
-        exec_stmt(coro, &mut cached.stmt).await?;
-        Ok(())
-    }
-    async fn cached_delete_stmt(
-        &mut self,
-        coro: &Coro,
-        table_name: &str,
-    ) -> Result<&mut DeleteCachedStmt> {
-        if !self.cached_delete_stmt.contains_key(table_name) {
-            tracing::trace!("prepare delete statement for replay: table={}", table_name);
-            let stmt = self.delete_query(coro, table_name).await?;
-            self.cached_delete_stmt.insert(table_name.to_string(), stmt);
-        }
-        tracing::trace!(
-            "ready to use prepared delete statement for replay: table={}",
-            table_name
-        );
-        let cached = self.cached_delete_stmt.get_mut(table_name).unwrap();
-        cached.stmt.reset();
-        Ok(cached)
-    }
-    async fn cached_insert_stmt(
-        &mut self,
-        coro: &Coro,
-        table_name: &str,
+        table: &str,
         columns: usize,
-    ) -> Result<&mut turso_core::Statement> {
-        let key = (table_name.to_string(), columns);
-        if !self.cached_insert_stmt.contains_key(&key) {
-            tracing::trace!(
-                "prepare insert statement for replay: table={}, columns={}",
-                table_name,
-                columns
-            );
-            let stmt = self.insert_query(coro, table_name, columns).await?;
-            self.cached_insert_stmt.insert(key.clone(), stmt);
+    ) -> Result<(String, usize)> {
+        let key = (table.to_string(), columns);
+        if self.cached_insert_stmt.contains_key(&key) {
+            return Ok(key);
         }
         tracing::trace!(
-            "ready to use prepared insert statement for replay: table={}, columns={}",
-            table_name,
+            "prepare insert statement for replay: table={}, columns={}",
+            table,
             columns
         );
-        let stmt = self.cached_insert_stmt.get_mut(&key).unwrap();
-        stmt.reset();
-        Ok(stmt)
+        let info = self.generator.insert_query(coro, table, columns).await?;
+        let stmt = self.conn.prepare(&info.query)?;
+        self.cached_insert_stmt
+            .insert(key.clone(), CachedStmt { stmt, info });
+        Ok(key)
     }
-    async fn cached_update_stmt(
+    async fn populate_update_stmt(
        &mut self,
        coro: &Coro,
-        table_name: &str,
+        table: &str,
        columns: &[bool],
-    ) -> Result<&mut UpdateCachedStmt> {
-        let key = (table_name.to_string(), columns.to_owned());
-        if !self.cached_update_stmt.contains_key(&key) {
-            tracing::trace!("prepare update statement for replay: table={}", table_name);
-            let stmt = self.update_query(coro, table_name, columns).await?;
-            self.cached_update_stmt.insert(key.clone(), stmt);
+    ) -> Result<(String, Vec<bool>)> {
+        let key = (table.to_string(), columns.to_owned());
+        if self.cached_update_stmt.contains_key(&key) {
+            return Ok(key);
         }
-        tracing::trace!(
-            "ready to use prepared update statement for replay: table={}",
-            table_name
-        );
-        let cached = self.cached_update_stmt.get_mut(&key).unwrap();
-        cached.stmt.reset();
-        Ok(cached)
+        tracing::trace!("prepare update statement for replay: table={}", table);
+        let info = self.generator.update_query(coro, table, columns).await?;
+        let stmt = self.conn.prepare(&info.query)?;
+        self.cached_update_stmt
+            .insert(key.clone(), CachedStmt { stmt, info });
+        Ok(key)
     }
-    async fn insert_query(
-        &self,
-        coro: &Coro,
-        table_name: &str,
-        columns: usize,
-    ) -> Result<turso_core::Statement> {
-        let query = if !self.opts.use_implicit_rowid {
-            let placeholders = ["?"].repeat(columns).join(",");
-            format!("INSERT INTO {table_name} VALUES ({placeholders})")
-        } else {
-            let mut table_info_stmt = self.conn.prepare(format!(
-                "SELECT name FROM pragma_table_info('{table_name}')"
-            ))?;
-            let mut column_names = Vec::with_capacity(columns + 1);
-            while let Some(column) = run_stmt_once(coro, &mut table_info_stmt).await? {
-                let turso_core::Value::Text(text) = column.get_value(0) else {
-                    return Err(Error::DatabaseTapeError(
-                        "unexpected column type for pragma_table_info query".to_string(),
-                    ));
-                };
-                column_names.push(text.to_string());
-            }
-            column_names.push("rowid".to_string());
-
-            let placeholders = ["?"].repeat(columns + 1).join(",");
-            let column_names = column_names.join(", ");
-            format!("INSERT INTO {table_name}({column_names}) VALUES ({placeholders})")
-        };
-        Ok(self.conn.prepare(&query)?)
-    }
-    async fn delete_query(&self, coro: &Coro, table_name: &str) -> Result<DeleteCachedStmt> {
-        let (query, pk_column_indices) = if self.opts.use_implicit_rowid {
-            (format!("DELETE FROM {table_name} WHERE rowid = ?"), None)
-        } else {
-            let mut pk_info_stmt = self.conn.prepare(format!(
-                "SELECT cid, name FROM pragma_table_info('{table_name}') WHERE pk = 1"
-            ))?;
-            let mut pk_predicates = Vec::with_capacity(1);
-            let mut pk_column_indices = Vec::with_capacity(1);
-            while let Some(column) = run_stmt_once(coro, &mut pk_info_stmt).await? {
-                let turso_core::Value::Integer(column_id) = column.get_value(0) else {
-                    return Err(Error::DatabaseTapeError(
-                        "unexpected column type for pragma_table_info query".to_string(),
-                    ));
-                };
-                let turso_core::Value::Text(name) = column.get_value(1) else {
-                    return Err(Error::DatabaseTapeError(
-                        "unexpected column type for pragma_table_info query".to_string(),
-                    ));
-                };
-                pk_predicates.push(format!("{name} = ?"));
-                pk_column_indices.push(*column_id as usize);
-            }
-
-            if pk_column_indices.is_empty() {
-                (format!("DELETE FROM {table_name} WHERE rowid = ?"), None)
-            } else {
-                let pk_predicates = pk_predicates.join(" AND ");
-                let query = format!("DELETE FROM {table_name} WHERE {pk_predicates}");
-                (query, Some(pk_column_indices))
-            }
-        };
-        let use_implicit_rowid = self.opts.use_implicit_rowid;
-        tracing::trace!("delete_query: table_name={table_name}, query={query}, use_implicit_rowid={use_implicit_rowid}");
-        let stmt = self.conn.prepare(&query)?;
-        Ok(DeleteCachedStmt {
-            stmt,
-            pk_column_indices,
-        })
-    }
-    async fn update_query(
-        &self,
-        coro: &Coro,
-        table_name: &str,
-        columns: &[bool],
-    ) -> Result<UpdateCachedStmt> {
-        let mut table_info_stmt = self.conn.prepare(format!(
-            "SELECT cid, name, pk FROM pragma_table_info('{table_name}')"
-        ))?;
-        let mut pk_predicates = Vec::with_capacity(1);
-        let mut pk_column_indices = Vec::with_capacity(1);
-        let mut column_updates = Vec::with_capacity(1);
-        while let Some(column) = run_stmt_once(coro, &mut table_info_stmt).await? {
-            let turso_core::Value::Integer(column_id) = column.get_value(0) else {
-                return Err(Error::DatabaseTapeError(
-                    "unexpected column type for pragma_table_info query".to_string(),
-                ));
-            };
-            let turso_core::Value::Text(name) = column.get_value(1) else {
-                return Err(Error::DatabaseTapeError(
-                    "unexpected column type for pragma_table_info query".to_string(),
-                ));
-            };
-            let turso_core::Value::Integer(pk) = column.get_value(2) else {
-                return Err(Error::DatabaseTapeError(
-                    "unexpected column type for pragma_table_info query".to_string(),
-                ));
-            };
-            if *pk == 1 {
-                pk_predicates.push(format!("{name} = ?"));
-                pk_column_indices.push(*column_id as usize);
-            }
-            if columns[*column_id as usize] {
-                column_updates.push(format!("{name} = ?"));
-            }
-        }
-
-        let (query, pk_column_indices) = if self.opts.use_implicit_rowid {
-            (
-                format!(
-                    "UPDATE {table_name} SET {} WHERE rowid = ?",
-                    column_updates.join(", ")
-                ),
-                None,
-            )
-        } else {
-            (
-                format!(
-                    "UPDATE {table_name} SET {} WHERE {}",
-                    column_updates.join(", "),
-                    pk_predicates.join(" AND ")
-                ),
-                Some(pk_column_indices),
-            )
-        };
-        let stmt = self.conn.prepare(&query)?;
-        let cached_stmt = UpdateCachedStmt {
-            stmt,
-            pk_column_indices,
-        };
-        Ok(cached_stmt)
-    }
-}
-
-fn parse_bin_record(bin_record: Vec<u8>) -> Result<Vec<turso_core::Value>> {
-    let record = turso_core::types::ImmutableRecord::from_bin_record(bin_record);
-    let mut cursor = turso_core::types::RecordCursor::new();
-    let columns = cursor.count(&record);
-    let mut values = Vec::with_capacity(columns);
-    for i in 0..columns {
-        let value = cursor.get_value(&record, i)?;
-        values.push(value.to_owned());
-    }
-    Ok(values)
 }

 #[cfg(test)]
diff --git a/packages/turso-sync-engine/src/lib.rs b/packages/turso-sync-engine/src/lib.rs
index c67b2e363..0546a15e7 100644
--- a/packages/turso-sync-engine/src/lib.rs
+++ b/packages/turso-sync-engine/src/lib.rs
@@ -1,37 +1,19 @@
+pub mod database_replay_generator;
 pub mod database_sync_engine;
 pub mod database_sync_operations;
 pub mod database_tape;
 pub mod errors;
 pub mod io_operations;
 pub mod protocol_io;
+pub mod server_proto;
 pub mod types;
 pub mod wal_session;

-#[cfg(test)]
-pub mod test_context;
-#[cfg(test)]
-pub mod test_protocol_io;
-#[cfg(test)]
-pub mod test_sync_server;
-
 pub type Result<T> = std::result::Result<T, errors::Error>;

 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
-
-    use tokio::{select, sync::Mutex};
     use tracing_subscriber::EnvFilter;
-    use turso_core::IO;
-
-    use crate::{
-        database_sync_engine::{DatabaseSyncEngine, DatabaseSyncEngineOpts},
-        errors::Error,
-        test_context::TestContext,
-        test_protocol_io::TestProtocolIo,
-        types::{Coro, ProtocolCommand},
-        Result,
-    };

     #[ctor::ctor]
     fn init() {
@@ -41,14 +23,17 @@ mod tests {
             .init();
     }

+    #[allow(dead_code)]
     pub fn seed_u64() -> u64 {
         seed().parse().unwrap_or(0)
     }

+    #[allow(dead_code)]
     pub fn seed() -> String {
         std::env::var("SEED").unwrap_or("0".to_string())
     }

+    #[allow(dead_code)]
     pub fn deterministic_runtime_from_seed<F: std::future::Future<Output = ()>>(
         seed: &[u8],
         f: impl Fn() -> F,
@@ -63,107 +48,9 @@ mod tests {
         runtime.block_on(f());
     }

+    #[allow(dead_code)]
     pub fn deterministic_runtime<F: std::future::Future<Output = ()>>(f: impl Fn() -> F) {
         let seed = seed();
         deterministic_runtime_from_seed(seed.as_bytes(), f);
     }
-
-    pub struct TestRunner {
-        pub ctx: Arc<TestContext>,
-        pub io: Arc<dyn IO>,
-        pub sync_server: TestProtocolIo,
-        db: Option<Arc<Mutex<DatabaseSyncEngine<TestProtocolIo>>>>,
-    }
-
-    impl TestRunner {
-        pub fn new(ctx: Arc<TestContext>, io: Arc<dyn IO>, sync_server: TestProtocolIo) -> Self {
-            Self {
-                ctx,
-                io,
-                sync_server,
-                db: None,
-            }
-        }
-        pub async fn init(&mut self, local_path: &str, opts: DatabaseSyncEngineOpts) -> Result<()> {
-            let io = self.io.clone();
-            let server = self.sync_server.clone();
-            let db = self
-                .run(genawaiter::sync::Gen::new(|coro| async move {
-                    DatabaseSyncEngine::new(&coro, io, Arc::new(server), local_path, opts).await
-                }))
-                .await
-                .unwrap();
-            self.db = Some(Arc::new(Mutex::new(db)));
-            Ok(())
-        }
-        pub async fn connect(&self) -> Result<turso::Connection> {
-            self.run_db_fn(self.db.as_ref().unwrap(), async move |coro, db| {
-                Ok(turso::Connection::create(db.connect(coro).await?))
-            })
-            .await
-        }
-        pub async fn pull(&self) -> Result<()> {
-            self.run_db_fn(self.db.as_ref().unwrap(), async move |coro, db| {
-                db.pull(coro).await
-            })
-            .await
-        }
-        pub async fn push(&self) -> Result<()> {
-            self.run_db_fn(self.db.as_ref().unwrap(), async move |coro, db| {
-                db.push(coro).await
-            })
-            .await
-        }
-        pub async fn sync(&self) -> Result<()> {
-            self.run_db_fn(self.db.as_ref().unwrap(), async move |coro, db| {
-                db.sync(coro).await
-            })
-            .await
-        }
-        pub async fn run_db_fn<R>(
-            &self,
-            db: &Arc<Mutex<DatabaseSyncEngine<TestProtocolIo>>>,
-            f: impl AsyncFn(&Coro, &mut DatabaseSyncEngine<TestProtocolIo>) -> Result<R>,
-        ) -> Result<R> {
-            let g = genawaiter::sync::Gen::new({
-                let db = db.clone();
-                |coro| async move {
-                    let mut db = db.lock().await;
-                    f(&coro, &mut db).await
-                }
-            });
-            self.run(g).await
-        }
-        pub async fn run<R, F: std::future::Future<Output = Result<R>>>(
-            &self,
-            mut g: genawaiter::sync::Gen<ProtocolCommand, Result<()>, F>,
-        ) -> Result<R> {
-            let mut response = Ok(());
-            loop {
-                // we must drive internal tokio clocks on every iteration - otherwise one TestRunner without work can block everything
-                // if other TestRunner sleeping - as time will "freeze" in this case
-                self.ctx.random_sleep().await;
-
-                match g.resume_with(response) {
-                    genawaiter::GeneratorState::Complete(result) => return result,
-                    genawaiter::GeneratorState::Yielded(ProtocolCommand::IO) => {
-                        let drained = {
-                            let mut requests = self.sync_server.requests.lock().unwrap();
-                            requests.drain(..).collect::<Vec<_>>()
-                        };
-                        for mut request in drained {
-                            select! {
-                                value = &mut request => { value.unwrap(); },
-                                _ = self.ctx.random_sleep() => { self.sync_server.requests.lock().unwrap().push(request); }
-                            };
-                        }
-                        response =
-                            self.io.run_once().map(|_| ()).map_err(|e| {
-                                Error::DatabaseSyncEngineError(format!("io error: {e}"))
-                            });
-                    }
-                }
-            }
-        }
-    }
 }
diff --git a/packages/turso-sync-engine/src/server_proto.rs b/packages/turso-sync-engine/src/server_proto.rs
new file mode 100644
index 000000000..0289afe16
--- /dev/null
+++ b/packages/turso-sync-engine/src/server_proto.rs
@@ -0,0 +1,231 @@
+use std::collections::VecDeque;
+
+use bytes::Bytes;
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct PipelineReqBody {
+    pub baton: Option<String>,
+    pub requests: VecDeque<StreamRequest>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct PipelineRespBody {
+    pub baton: Option<String>,
+    pub base_url: Option<String>,
+    pub results: Vec<StreamResult>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Default)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum StreamRequest {
+    #[serde(skip_deserializing)]
+    #[default]
+    None,
+    /// See [`CloseStreamReq`]
+    Close(CloseStreamReq),
+    /// See [`ExecuteStreamReq`]
+    Execute(ExecuteStreamReq),
+}
+
+#[derive(Serialize, Deserialize, Default, Debug, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum StreamResult {
+    #[default]
+    None,
+    Ok,
+    Error {
+        error: Error,
+    },
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+/// A request to close the current stream.
+pub struct CloseStreamReq {}
+
+#[derive(Serialize, Deserialize, Debug)]
+/// A request to execute a single SQL statement.
+pub struct ExecuteStreamReq {
+    pub stmt: Stmt,
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
+pub struct Error {
+    pub message: String,
+    pub code: String,
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+/// A SQL statement to execute.
+pub struct Stmt {
+    #[serde(default)]
+    /// The SQL statement to execute.
+    pub sql: Option<String>,
+    #[serde(default)]
+    /// The ID of the SQL statement (if it is a stored statement; see [`crate::connections_manager::StreamResource`]).
+    pub sql_id: Option<i32>,
+    #[serde(default)]
+    /// The positional arguments to the SQL statement.
+    pub args: Vec<Value>,
+    #[serde(default)]
+    /// The named arguments to the SQL statement.
+    pub named_args: Vec<NamedArg>,
+    #[serde(default)]
+    /// Whether the SQL statement should return rows.
+    pub want_rows: Option<bool>,
+    #[serde(default, with = "option_u64_as_str")]
+    /// The replication index of the SQL statement (a LibSQL concept, currently not used).
+    pub replication_index: Option<u64>,
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+pub struct NamedArg {
+    pub name: String,
+    pub value: Value,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Value {
+    #[serde(skip_deserializing)]
+    #[default]
+    None,
+    Null,
+    Integer {
+        #[serde(with = "i64_as_str")]
+        value: i64,
+    },
+    Float {
+        value: f64,
+    },
+    Text {
+        value: String,
+    },
+    Blob {
+        #[serde(with = "bytes_as_base64", rename = "base64")]
+        value: Bytes,
+    },
+}
+
+pub mod option_u64_as_str {
+    use serde::de::Error;
+    use serde::{de::Visitor, ser, Deserializer, Serialize as _};
+
+    pub fn serialize<S: ser::Serializer>(value: &Option<u64>, ser: S) -> Result<S::Ok, S::Error> {
+        value.map(|v| v.to_string()).serialize(ser)
+    }
+
+    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<u64>, D::Error> {
+        struct V;
+
+        impl<'de> Visitor<'de> for V {
+            type Value = Option<u64>;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                write!(formatter, "a string representing an integer, or null")
+            }
+
+            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
+            where
+                D: Deserializer<'de>,
+            {
+                deserializer.deserialize_any(V)
+            }
+
+            fn visit_unit<E>(self) -> Result<Self::Value, E>
+            where
+                E: Error,
+            {
+                Ok(None)
+            }
+
+            fn visit_none<E>(self) -> Result<Self::Value, E>
+            where
+                E: Error,
+            {
+                Ok(None)
+            }
+
+            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
+            where
+                E: Error,
+            {
+                Ok(Some(v))
+            }
+
+            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
+            where
+                E: Error,
+            {
+                v.parse().map_err(E::custom).map(Some)
+            }
+        }
+
+        d.deserialize_option(V)
+    }
+
+    #[cfg(test)]
+    mod test {
+        use serde::Deserialize;
+
+        #[test]
+        fn deserialize_ok() {
+            #[derive(Deserialize)]
+            struct Test {
+                #[serde(with = "super")]
+                value: Option<u64>,
+            }
+
+            let json = r#"{"value": null }"#;
+            let val: Test = serde_json::from_str(json).unwrap();
+            assert!(val.value.is_none());
+
+            let json = r#"{"value": "124" }"#;
+            let val: Test = serde_json::from_str(json).unwrap();
+            assert_eq!(val.value.unwrap(), 124);
+
+            let json = r#"{"value": 124 }"#;
+            let val: Test = serde_json::from_str(json).unwrap();
+            assert_eq!(val.value.unwrap(), 124);
+        }
+    }
+}
+
+mod i64_as_str {
+    use serde::{de, ser};
+    use serde::{de::Error as _, Serialize as _};
+
+    pub fn serialize<S: ser::Serializer>(value: &i64, ser: S) -> Result<S::Ok, S::Error> {
+        value.to_string().serialize(ser)
+    }
+
+    pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<i64, D::Error> {
+        let str_value = <&'de str as de::Deserialize>::deserialize(de)?;
+        str_value.parse().map_err(|_| {
+            D::Error::invalid_value(
+                de::Unexpected::Str(str_value),
+                &"decimal integer as a string",
+            )
+        })
+    }
+}
+
+pub(crate) mod bytes_as_base64 {
+    use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine as _};
+    use bytes::Bytes;
+    use serde::{de, ser};
+    use serde::{de::Error as _, Serialize as _};
+
+    pub fn serialize<S: ser::Serializer>(value: &Bytes, ser: S) -> Result<S::Ok, S::Error> {
+        STANDARD_NO_PAD.encode(value).serialize(ser)
+    }
+
+    pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<Bytes, D::Error> {
+        let text = <&'de str as de::Deserialize>::deserialize(de)?;
+        let text = text.trim_end_matches('=');
+        let bytes = STANDARD_NO_PAD.decode(text).map_err(|_| {
+            D::Error::invalid_value(de::Unexpected::Str(text), &"binary data encoded as base64")
+        })?;
+        Ok(Bytes::from(bytes))
+    }
+}
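For a sense of the wire shape these serde helpers produce (integers travel as decimal strings, blobs as unpadded base64), here is a quick hypothetical round-trip; it is not part of the patch and assumes serde_json is available next to the crate:

use bytes::Bytes;
use turso_sync_engine::server_proto::Value;

fn main() {
    // Integer values are tagged "integer" and their payload is a decimal string
    // (the i64_as_str helper above).
    let int = serde_json::to_string(&Value::Integer { value: 42 }).unwrap();
    assert_eq!(int, r#"{"type":"integer","value":"42"}"#);

    // Blob values are tagged "blob" and carried as unpadded base64 under the
    // "base64" key (the bytes_as_base64 helper above).
    let blob = serde_json::to_string(&Value::Blob { value: Bytes::from(vec![1u8, 2, 3]) }).unwrap();
    assert_eq!(blob, r#"{"type":"blob","base64":"AQID"}"#);
}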
diff --git a/packages/turso-sync-engine/src/test_context.rs b/packages/turso-sync-engine/src/test_context.rs
deleted file mode 100644
index 3d67d135d..000000000
--- a/packages/turso-sync-engine/src/test_context.rs
+++ /dev/null
@@ -1,154 +0,0 @@
-use std::{
-    collections::{HashMap, HashSet},
-    future::Future,
-    pin::Pin,
-    sync::Arc,
-};
-
-use rand::{RngCore, SeedableRng};
-use rand_chacha::ChaCha8Rng;
-use tokio::sync::Mutex;
-
-use crate::{errors::Error, Result};
-
-type PinnedFuture = Pin<Box<dyn Future<Output = bool> + Send>>;
-
-pub struct FaultInjectionPlan {
-    pub is_fault: Box<dyn Fn(String, String) -> PinnedFuture + Send + Sync>,
-}
-
-pub enum FaultInjectionStrategy {
-    Disabled,
-    Record,
-    Enabled { plan: FaultInjectionPlan },
-}
-
-impl std::fmt::Debug for FaultInjectionStrategy {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Self::Disabled => write!(f, "Disabled"),
-            Self::Record => write!(f, "Record"),
-            Self::Enabled { .. } => write!(f, "Enabled"),
-        }
-    }
-}
-
-pub struct TestContext {
-    fault_injection: Mutex<FaultInjectionStrategy>,
-    faulty_call: Mutex<HashSet<(String, String)>>,
-    rng: Mutex<ChaCha8Rng>,
-}
-
-pub struct FaultSession {
-    ctx: Arc<TestContext>,
-    recording: bool,
-    plans: Option<Vec<FaultInjectionPlan>>,
-}
-
-impl FaultSession {
-    pub async fn next(&mut self) -> Option<FaultInjectionStrategy> {
-        if !self.recording {
-            self.recording = true;
-            return Some(FaultInjectionStrategy::Record);
-        }
-        if self.plans.is_none() {
-            self.plans = Some(self.ctx.enumerate_simple_plans().await);
-        }
-
-        let plans = self.plans.as_mut().unwrap();
-        if plans.is_empty() {
-            return None;
-        }
-
-        let plan = plans.pop().unwrap();
-        Some(FaultInjectionStrategy::Enabled { plan })
-    }
-}
-
-impl TestContext {
-    pub fn new(seed: u64) -> Self {
-        Self {
-            rng: Mutex::new(ChaCha8Rng::seed_from_u64(seed)),
-            fault_injection: Mutex::new(FaultInjectionStrategy::Disabled),
-            faulty_call: Mutex::new(HashSet::new()),
-        }
-    }
-    pub async fn random_sleep(&self) {
-        let delay = self.rng.lock().await.next_u64() % 1000;
-        tokio::time::sleep(std::time::Duration::from_millis(delay)).await
-    }
-    pub async fn random_sleep_n(&self, n: u64) {
-        let delay = {
-            let mut rng = self.rng.lock().await;
-            rng.next_u64() % 1000 * (rng.next_u64() % n + 1)
-        };
-        tokio::time::sleep(std::time::Duration::from_millis(delay)).await
-    }
-
-    pub async fn rng(&self) -> tokio::sync::MutexGuard<'_, ChaCha8Rng> {
-        self.rng.lock().await
-    }
-    pub fn fault_session(self: &Arc<Self>) -> FaultSession {
-        FaultSession {
-            ctx: self.clone(),
-            recording: false,
-            plans: None,
-        }
-    }
-    pub async fn switch_mode(&self, updated: FaultInjectionStrategy) {
-        let mut mode = self.fault_injection.lock().await;
-        tracing::info!("switch fault injection mode: {:?}", updated);
-        *mode = updated;
-    }
-    pub async fn enumerate_simple_plans(&self) -> Vec<FaultInjectionPlan> {
-        let mut plans = vec![];
-        for call in self.faulty_call.lock().await.iter() {
-            let mut fault_counts = HashMap::new();
-            fault_counts.insert(call.clone(), 1);
-
-            let count = Arc::new(Mutex::new(1));
-            let call = call.clone();
-            plans.push(FaultInjectionPlan {
-                is_fault: Box::new(move |name, bt| {
-                    let call = call.clone();
-                    let count = count.clone();
-                    Box::pin(async move {
-                        if (name, bt) != call {
-                            return false;
-                        }
-                        let mut count = count.lock().await;
-                        *count -= 1;
-                        *count >= 0
-                    })
-                }),
-            })
-        }
-        plans
-    }
-    pub async fn faulty_call(&self, name: &str) -> Result<()> {
-        tracing::trace!("faulty_call: {}", name);
-
-        // sleep here in order for scheduler to interleave different executions
-        self.random_sleep().await;
-
-        if let FaultInjectionStrategy::Disabled = &*self.fault_injection.lock().await {
-            return Ok(());
-        }
-        let bt = std::backtrace::Backtrace::force_capture().to_string();
-        match &mut *self.fault_injection.lock().await {
-            FaultInjectionStrategy::Record => {
-                let mut call_sites = self.faulty_call.lock().await;
-                call_sites.insert((name.to_string(), bt));
-                Ok(())
-            }
-            FaultInjectionStrategy::Enabled { plan } => {
-                if plan.is_fault.as_ref()(name.to_string(), bt.clone()).await {
-                    Err(Error::DatabaseSyncEngineError("injected fault".to_string()))
-                } else {
-                    Ok(())
-                }
-            }
-            _ => unreachable!("Disabled case handled above"),
-        }
-    }
-}
diff --git a/packages/turso-sync-engine/src/test_empty.db b/packages/turso-sync-engine/src/test_empty.db
deleted file mode 100644
index 0a06b0094..000000000
Binary files a/packages/turso-sync-engine/src/test_empty.db and /dev/null differ
diff --git a/packages/turso-sync-engine/src/test_protocol_io.rs b/packages/turso-sync-engine/src/test_protocol_io.rs
deleted file mode 100644
index 09e00e265..000000000
--- a/packages/turso-sync-engine/src/test_protocol_io.rs
+++ /dev/null
@@ -1,222 +0,0 @@
-use std::{
-    collections::{HashMap, VecDeque},
-    path::Path,
-    pin::Pin,
-    sync::Arc,
-};
-
-use tokio::{sync::Mutex, task::JoinHandle};
-
-use crate::{
-    errors::Error,
-    protocol_io::{DataCompletion, DataPollResult, ProtocolIO},
-    test_context::TestContext,
-    test_sync_server::TestSyncServer,
-    Result,
-};
-
-#[derive(Clone)]
-pub struct TestProtocolIo {
-    #[allow(clippy::type_complexity)]
-    pub requests: Arc<std::sync::Mutex<Vec<Pin<Box<JoinHandle<()>>>>>>,
-    pub server: TestSyncServer,
-    ctx: Arc<TestContext>,
-    files: Arc<Mutex<HashMap<String, Vec<u8>>>>,
-}
-
-pub struct TestDataPollResult(Vec<u8>);
-
-impl DataPollResult for TestDataPollResult {
-    fn data(&self) -> &[u8] {
-        &self.0
-    }
-}
-
-#[derive(Clone)]
-pub struct TestDataCompletion {
-    status: Arc<std::sync::Mutex<Option<u16>>>,
-    chunks: Arc<std::sync::Mutex<VecDeque<Vec<u8>>>>,
-    done: Arc<std::sync::Mutex<bool>>,
-    poisoned: Arc<std::sync::Mutex<Option<String>>>,
-}
-
-impl Default for TestDataCompletion {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl TestDataCompletion {
-    pub fn new() -> Self {
-        Self {
-            status: Arc::new(std::sync::Mutex::new(None)),
-            chunks: Arc::new(std::sync::Mutex::new(VecDeque::new())),
-            done: Arc::new(std::sync::Mutex::new(false)),
-            poisoned: Arc::new(std::sync::Mutex::new(None)),
-        }
-    }
-    pub fn set_status(&self, status: u16) {
-        *self.status.lock().unwrap() = Some(status);
-    }
-
-    pub fn push_data(&self, data: Vec<u8>) {
-        let mut chunks = self.chunks.lock().unwrap();
-        chunks.push_back(data);
-    }
-
-    pub fn set_done(&self) {
-        *self.done.lock().unwrap() = true;
-    }
-
-    pub fn poison(&self, err: &str) {
-        *self.poisoned.lock().unwrap() = Some(err.to_string());
-    }
-}
-
-impl DataCompletion for TestDataCompletion {
-    type DataPollResult = TestDataPollResult;
-
-    fn status(&self) -> Result<Option<u16>> {
-        let poison = self.poisoned.lock().unwrap();
-        if poison.is_some() {
-            return Err(Error::DatabaseSyncEngineError(format!(
-                "status error: {poison:?}"
-            )));
-        }
-        Ok(*self.status.lock().unwrap())
-    }
-
-    fn poll_data(&self) -> Result<Option<Self::DataPollResult>> {
-        let poison = self.poisoned.lock().unwrap();
-        if poison.is_some() {
-            return Err(Error::DatabaseSyncEngineError(format!(
-                "poll_data error: {poison:?}"
-            )));
-        }
-        let mut chunks = self.chunks.lock().unwrap();
-        Ok(chunks.pop_front().map(TestDataPollResult))
-    }
-
-    fn is_done(&self) -> Result<bool> {
-        let poison = self.poisoned.lock().unwrap();
-        if poison.is_some() {
-            return Err(Error::DatabaseSyncEngineError(format!(
-                "is_done error: {poison:?}"
-            )));
-        }
-        Ok(*self.done.lock().unwrap())
-    }
-}
-
-impl TestProtocolIo {
-    pub async fn new(ctx: Arc<TestContext>, path: &Path) -> Result<Self> {
-        Ok(Self {
-            ctx: ctx.clone(),
-            requests: Arc::new(std::sync::Mutex::new(Vec::new())),
-            server: TestSyncServer::new(ctx, path).await?,
-            files: Arc::new(Mutex::new(HashMap::new())),
-        })
-    }
-    fn schedule<
-        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
-        F: FnOnce(TestSyncServer, TestDataCompletion) -> Fut + Send + 'static,
-    >(
-        &self,
-        completion: TestDataCompletion,
-        f: F,
-    ) {
-        let server = self.server.clone();
-        let mut requests = self.requests.lock().unwrap();
-        requests.push(Box::pin(tokio::spawn(async move {
-            if let Err(err) = f(server, completion.clone()).await {
-                tracing::info!("poison completion: {}", err);
-                completion.poison(&err.to_string());
-            }
-        })));
-    }
-}
-
-impl ProtocolIO for TestProtocolIo {
-    type DataCompletion = TestDataCompletion;
-    fn http(&self, method: &str, path: &str, data: Option<Vec<u8>>) -> Result<Self::DataCompletion> {
-        let completion = TestDataCompletion::new();
-        {
-            let completion = completion.clone();
-            let path = &path[1..].split("/").collect::<Vec<_>>();
-            match (method, path.as_slice()) {
-                ("GET", ["info"]) => {
-                    self.schedule(completion, |s, c| async move { s.db_info(c).await });
-                }
-                ("GET", ["export", generation]) => {
-                    let generation = generation.parse().unwrap();
-                    self.schedule(completion, async move |s, c| {
-                        s.db_export(c, generation).await
-                    });
-                }
-                ("GET", ["sync", generation, start, end]) => {
-                    let generation = generation.parse().unwrap();
-                    let start = start.parse().unwrap();
-                    let end = end.parse().unwrap();
-                    self.schedule(completion, async move |s, c| {
-                        s.wal_pull(c, generation, start, end).await
-                    });
-                }
-                ("POST", ["sync", generation, start, end]) => {
-                    let generation = generation.parse().unwrap();
-                    let start = start.parse().unwrap();
-                    let end = end.parse().unwrap();
-                    let data = data.unwrap();
-                    self.schedule(completion, async move |s, c| {
-                        s.wal_push(c, None, generation, start, end, data).await
-                    });
-                }
-                ("POST", ["sync", generation, start, end, baton]) => {
-                    let baton = baton.to_string();
-                    let generation = generation.parse().unwrap();
-                    let start = start.parse().unwrap();
-                    let end = end.parse().unwrap();
-                    let data = data.unwrap();
-                    self.schedule(completion, async move |s, c| {
-                        s.wal_push(c, Some(baton), generation, start, end, data)
-                            .await
-                    });
-                }
-                _ => panic!("unexpected sync server request: {method} {path:?}"),
-            };
-        }
-        Ok(completion)
-    }
-
-    fn full_read(&self, path: &str) -> Result<Self::DataCompletion> {
-        let completion = TestDataCompletion::new();
-        let ctx = self.ctx.clone();
-        let files = self.files.clone();
-        let path = path.to_string();
-        self.schedule(completion.clone(), async move |_, c| {
-            ctx.faulty_call("full_read_start").await?;
-            let files = files.lock().await;
-            let result = files.get(&path);
-            c.push_data(result.cloned().unwrap_or(Vec::new()));
-            ctx.faulty_call("full_read_end").await?;
-            c.set_done();
-            Ok(())
-        });
-        Ok(completion)
-    }
-
-    fn full_write(&self, path: &str, content: Vec<u8>) -> Result<Self::DataCompletion> {
-        let completion = TestDataCompletion::new();
-        let ctx = self.ctx.clone();
-        let files = self.files.clone();
-        let path = path.to_string();
-        self.schedule(completion.clone(), async move |_, c| {
-            ctx.faulty_call("full_write_start").await?;
-            let mut files = files.lock().await;
-            files.insert(path, content);
-            ctx.faulty_call("full_write_end").await?;
-            c.set_done();
-            Ok(())
-        });
-        Ok(completion)
-    }
-}
diff --git a/packages/turso-sync-engine/src/test_sync_server.rs b/packages/turso-sync-engine/src/test_sync_server.rs
deleted file mode 100644
index ae5be4a22..000000000
--- a/packages/turso-sync-engine/src/test_sync_server.rs
+++ /dev/null
@@ -1,351 +0,0 @@
-use std::{
-    collections::HashMap,
-    path::{Path, PathBuf},
-    sync::Arc,
-};
-
-use tokio::sync::Mutex;
-
-use crate::{
-    errors::Error,
-    test_context::TestContext,
-    test_protocol_io::TestDataCompletion,
-    types::{DbSyncInfo, DbSyncStatus},
-    Result,
-};
-
-const PAGE_SIZE: usize = 4096;
-const FRAME_SIZE: usize = 24 + PAGE_SIZE;
-
-struct Generation {
-    snapshot: Vec<u8>,
-    frames: Vec<Vec<u8>>,
-}
-
-#[derive(Clone)]
-struct SyncSession {
-    baton: String,
-    conn: turso::Connection,
-    in_txn: bool,
-}
-
-struct TestSyncServerState {
-    generation: u64,
-    generations: HashMap<u64, Generation>,
-    sessions: HashMap<String, SyncSession>,
-}
-
-#[derive(Clone)]
-pub struct TestSyncServer {
-    path: PathBuf,
-    ctx: Arc<TestContext>,
-    db: turso::Database,
-    state: Arc<Mutex<TestSyncServerState>>,
-}
-
-impl TestSyncServer {
-    pub async fn new(ctx: Arc<TestContext>, path: &Path) -> Result<Self> {
-        let mut generations = HashMap::new();
-        generations.insert(
-            1,
-            Generation {
-                snapshot: EMPTY_WAL_MODE_DB.to_vec(),
-                frames: Vec::new(),
-            },
-        );
-        Ok(Self {
-            path: path.to_path_buf(),
-            ctx,
-            db: turso::Builder::new_local(path.to_str().unwrap())
-                .build()
-                .await?,
-            state: Arc::new(Mutex::new(TestSyncServerState {
-                generation: 1,
-                generations,
-                sessions: HashMap::new(),
-            })),
-        })
-    }
-    pub async fn db_info(&self, completion: TestDataCompletion) -> Result<()> {
-        tracing::debug!("db_info");
-        self.ctx.faulty_call("db_info_start").await?;
-
-        let state = self.state.lock().await;
-        let result = DbSyncInfo {
-            current_generation: state.generation,
-        };
-
-        completion.set_status(200);
-        self.ctx.faulty_call("db_info_status").await?;
-
-        completion.push_data(serde_json::to_vec(&result)?);
-        self.ctx.faulty_call("db_info_data").await?;
-
-        completion.set_done();
-
-        Ok(())
-    }
-
-    pub async fn db_export(
-        &self,
-        completion: TestDataCompletion,
-        generation_id: u64,
-    ) -> Result<()> {
-        tracing::debug!("db_export: {}", generation_id);
-        self.ctx.faulty_call("db_export_start").await?;
-
-        let state = self.state.lock().await;
-        let Some(generation) = state.generations.get(&generation_id) else {
-            return Err(Error::DatabaseSyncEngineError(
-                "generation not found".to_string(),
-            ));
-        };
-        completion.set_status(200);
-        self.ctx.faulty_call("db_export_status").await?;
-
-        completion.push_data(generation.snapshot.clone());
-        self.ctx.faulty_call("db_export_push").await?;
-
-        completion.set_done();
-
-        Ok(())
-    }
-
-    pub async fn wal_pull(
-        &self,
-        completion: TestDataCompletion,
-        generation_id: u64,
-        start_frame: u64,
-        end_frame: u64,
-    ) -> Result<()> {
-        tracing::debug!("wal_pull: {}/{}/{}", generation_id, start_frame, end_frame);
-        self.ctx.faulty_call("wal_pull_start").await?;
-
-        let state = self.state.lock().await;
-        let Some(generation) = state.generations.get(&generation_id) else {
-            return Err(Error::DatabaseSyncEngineError(
-                "generation not found".to_string(),
-            ));
-        };
-        let mut data = Vec::new();
-        for frame_no in start_frame..end_frame {
-            let frame_idx = frame_no - 1;
-            let Some(frame) = generation.frames.get(frame_idx as usize) else {
-                break;
-            };
-            data.extend_from_slice(frame);
-        }
-        if data.is_empty() {
-            let last_generation = state.generations.get(&state.generation).unwrap();
-
-            let status = DbSyncStatus {
-                baton: None,
-                status: "checkpoint_needed".to_string(),
-                generation: state.generation,
-                max_frame_no: last_generation.frames.len() as u64,
-            };
-            completion.set_status(400);
-            self.ctx.faulty_call("wal_pull_400_status").await?;
-
-            completion.push_data(serde_json::to_vec(&status)?);
-            self.ctx.faulty_call("wal_pull_400_push").await?;
-
-            completion.set_done();
-        } else {
-            completion.set_status(200);
-            self.ctx.faulty_call("wal_pull_200_status").await?;
-
-            completion.push_data(data);
-            self.ctx.faulty_call("wal_pull_200_push").await?;
-
-            completion.set_done();
-        };
-
-        Ok(())
-    }
-
-    pub async fn wal_push(
-        &self,
-        completion: TestDataCompletion,
-        mut baton: Option<String>,
-        generation_id: u64,
-        start_frame: u64,
-        end_frame: u64,
-        frames: Vec<u8>,
-    ) -> Result<()> {
-        tracing::debug!(
-            "wal_push: {}/{}/{}/{:?}",
-            generation_id,
-            start_frame,
-            end_frame,
-            baton
-        );
-        self.ctx.faulty_call("wal_push_start").await?;
-
-        let mut session = {
-            let mut state = self.state.lock().await;
-            if state.generation != generation_id {
-                let generation = state.generations.get(&state.generation).unwrap();
-                let max_frame_no = generation.frames.len();
-                let status = DbSyncStatus {
-                    baton: None,
-                    status: "checkpoint_needed".to_string(),
-                    generation: state.generation,
-                    max_frame_no: max_frame_no as u64,
-                };
-
-                let status = serde_json::to_vec(&status)?;
-
-                completion.set_status(200);
-                self.ctx.faulty_call("wal_push_status").await?;
-
-                completion.push_data(status);
-                self.ctx.faulty_call("wal_push_push").await?;
-
-                completion.set_done();
-                return Ok(());
-            }
-            let baton_str = baton.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
-            let session = match state.sessions.get(&baton_str) {
-                Some(session) => session.clone(),
-                None => {
-                    let session = SyncSession {
-                        baton: baton_str.clone(),
-                        conn: self.db.connect()?,
-                        in_txn: false,
-                    };
-                    state.sessions.insert(baton_str.clone(), session.clone());
-                    session
-                }
-            };
-            baton = Some(baton_str.clone());
-            session
-        };
-
-        let conflict = 'conflict: {
-            let mut offset = 0;
-            for frame_no in start_frame..end_frame {
-                if offset + FRAME_SIZE > frames.len() {
-                    return Err(Error::DatabaseSyncEngineError(
-                        "unexpected length of frames data".to_string(),
-                    ));
-                }
-                if !session.in_txn {
-                    session.conn.wal_insert_begin()?;
-                    session.in_txn = true;
-                }
-                let frame = &frames[offset..offset + FRAME_SIZE];
-                match session.conn.wal_insert_frame(frame_no, frame) {
-                    Ok(info) => {
-                        if info.is_commit_frame() {
-                            if session.in_txn {
-                                session.conn.wal_insert_end()?;
-                                session.in_txn = false;
-                            }
-                            self.sync_frames_from_conn(&session.conn).await?;
-                        }
-                    }
-                    Err(turso::Error::WalOperationError(err)) if err.contains("Conflict") => {
-                        session.conn.wal_insert_end()?;
-                        break 'conflict true;
-                    }
-                    Err(err) => {
-                        session.conn.wal_insert_end()?;
-                        return Err(err.into());
-                    }
-                }
-                offset += FRAME_SIZE;
-            }
-            false
-        };
-        let mut state = self.state.lock().await;
-        state
-            .sessions
-            .insert(baton.clone().unwrap(), session.clone());
-        let status = DbSyncStatus {
-            baton: Some(session.baton.clone()),
-            status: if conflict { "conflict" } else { "ok" }.into(),
-            generation: state.generation,
-            max_frame_no: session.conn.wal_frame_count()?,
-        };
-
-        let status = serde_json::to_vec(&status)?;
-
-        completion.set_status(200);
-        self.ctx.faulty_call("wal_push_status").await?;
-
-        completion.push_data(status);
-        self.ctx.faulty_call("wal_push_push").await?;
-
-        completion.set_done();
-
-        Ok(())
-    }
-
-    pub fn db(&self) -> turso::Database {
-        self.db.clone()
-    }
-    pub async fn checkpoint(&self) -> Result<()> {
-        tracing::debug!("checkpoint sync-server db");
-        let conn = self.db.connect()?;
-        let mut rows = conn.query("PRAGMA wal_checkpoint(TRUNCATE)", ()).await?;
-        let Some(_) = rows.next().await? else {
-            return Err(Error::DatabaseSyncEngineError(
-                "checkpoint must return single row".to_string(),
-            ));
-        };
-        let mut state = self.state.lock().await;
-        let generation = state.generation + 1;
-        state.generation = generation;
-        state.generations.insert(
-            generation,
-            Generation {
-                snapshot: std::fs::read(&self.path).map_err(|e| {
-                    Error::DatabaseSyncEngineError(format!(
-                        "failed to create generation snapshot: {e}"
-                    ))
-                })?,
-                frames: Vec::new(),
-            },
-        );
-        Ok(())
-    }
-    pub async fn execute(&self, sql: &str, params: impl turso::IntoParams) -> Result<()> {
-        let conn = self.db.connect()?;
-        conn.execute(sql, params).await?;
-        tracing::debug!("sync_frames_from_conn after execute");
-        self.sync_frames_from_conn(&conn).await?;
-        Ok(())
-    }
-    async fn sync_frames_from_conn(&self, conn: &turso::Connection) -> Result<()> {
-        let mut state = self.state.lock().await;
-        let generation = state.generation;
-        let generation = state.generations.get_mut(&generation).unwrap();
-        let last_frame = generation.frames.len() + 1;
-        let mut frame = [0u8; FRAME_SIZE];
-        let wal_frame_count = conn.wal_frame_count()?;
-        tracing::debug!("conn frames count: {}", wal_frame_count);
-        for frame_no in last_frame..=wal_frame_count as usize {
-            let frame_info = conn.wal_get_frame(frame_no as u64, &mut frame)?;
-            tracing::debug!("push local frame {}, info={:?}", frame_no, frame_info);
-            generation.frames.push(frame.to_vec());
-        }
-        Ok(())
-    }
-}
-
-// empty DB with single 4096-byte page and WAL mode (PRAGMA journal_mode=WAL)
-// see test test_empty_wal_mode_db_content which validates asset content
-pub const EMPTY_WAL_MODE_DB: &[u8] = include_bytes!("test_empty.db");
-
-pub async fn convert_rows(rows: &mut turso::Rows) -> Result<Vec<Vec<turso::Value>>> {
-    let mut rows_values = vec![];
-    while let Some(row) = rows.next().await? {
-        let mut row_values = vec![];
-        for i in 0..row.column_count() {
-            row_values.push(row.get_value(i)?);
-        }
-        rows_values.push(row_values);
-    }
-    Ok(rows_values)
-}
diff --git a/packages/turso-sync-engine/src/types.rs b/packages/turso-sync-engine/src/types.rs
index c4fd297d1..5b32462c0 100644
--- a/packages/turso-sync-engine/src/types.rs
+++ b/packages/turso-sync-engine/src/types.rs
@@ -75,27 +75,31 @@ impl DatabaseChange {
     pub fn into_apply(self) -> Result<DatabaseTapeRowChange> {
         let tape_change = match self.change_type {
             DatabaseChangeType::Delete => DatabaseTapeRowChangeType::Delete {
-                before: self.before.ok_or_else(|| {
+                before: parse_bin_record(self.before.ok_or_else(|| {
                     Error::DatabaseTapeError(
                         "cdc_mode must be set to either 'full' or 'before'".to_string(),
                     )
-                })?,
+                })?)?,
             },
             DatabaseChangeType::Update => DatabaseTapeRowChangeType::Update {
-                before: self.before.ok_or_else(|| {
+                before: parse_bin_record(self.before.ok_or_else(|| {
                     Error::DatabaseTapeError("cdc_mode must be set to 'full'".to_string())
-                })?,
-                after: self.after.ok_or_else(|| {
+                })?)?,
+                after: parse_bin_record(self.after.ok_or_else(|| {
                     Error::DatabaseTapeError("cdc_mode must be set to 'full'".to_string())
-                })?,
-                updates: self.updates,
+                })?)?,
+                updates: if let Some(updates) = self.updates {
+                    Some(parse_bin_record(updates)?)
+                } else {
+                    None
+                },
             },
             DatabaseChangeType::Insert => DatabaseTapeRowChangeType::Insert {
-                after: self.after.ok_or_else(|| {
+                after: parse_bin_record(self.after.ok_or_else(|| {
                     Error::DatabaseTapeError(
                         "cdc_mode must be set to either 'full' or 'after'".to_string(),
                     )
-                })?,
+                })?)?,
             },
         };
         Ok(DatabaseTapeRowChange {
@@ -110,29 +114,29 @@ impl DatabaseChange {
     pub fn into_revert(self) -> Result<DatabaseTapeRowChange> {
         let tape_change = match self.change_type {
             DatabaseChangeType::Delete => DatabaseTapeRowChangeType::Insert {
-                after: self.before.ok_or_else(|| {
+                after: parse_bin_record(self.before.ok_or_else(|| {
                     Error::DatabaseTapeError(
                         "cdc_mode must be set to either 'full' or 'before'".to_string(),
                     )
-                })?,
+                })?)?,
             },
             DatabaseChangeType::Update => DatabaseTapeRowChangeType::Update {
-                before: self.after.ok_or_else(|| {
+                before: parse_bin_record(self.after.ok_or_else(|| {
                     Error::DatabaseTapeError("cdc_mode must be set to 'full'".to_string())
-                })?,
-                after: self.before.ok_or_else(|| {
+                })?)?,
+                after: parse_bin_record(self.before.ok_or_else(|| {
                     Error::DatabaseTapeError(
                         "cdc_mode must be set to either 'full' or 'before'".to_string(),
                     )
-                })?,
+                })?)?,
                 updates: None,
             },
             DatabaseChangeType::Insert => DatabaseTapeRowChangeType::Delete {
-                before: self.after.ok_or_else(|| {
+                before: parse_bin_record(self.after.ok_or_else(|| {
                     Error::DatabaseTapeError(
                         "cdc_mode must be set to either 'full' or 'after'".to_string(),
                     )
-                })?,
+                })?)?,
             },
         };
         Ok(DatabaseTapeRowChange {
@@ -197,18 +201,28 @@ impl TryFrom<&turso_core::Row> for DatabaseChange {

 pub enum DatabaseTapeRowChangeType {
     Delete {
-        before: Vec<u8>,
+        before: Vec<turso_core::Value>,
     },
     Update {
-        before: Vec<u8>,
-        after: Vec<u8>,
-        updates: Option<Vec<u8>>,
+        before: Vec<turso_core::Value>,
+        after: Vec<turso_core::Value>,
+        updates: Option<Vec<turso_core::Value>>,
     },
     Insert {
-        after: Vec<u8>,
+        after: Vec<turso_core::Value>,
     },
 }

+impl From<&DatabaseTapeRowChangeType> for DatabaseChangeType {
+    fn from(value: &DatabaseTapeRowChangeType) -> Self {
+        match value {
+            DatabaseTapeRowChangeType::Delete { .. } => DatabaseChangeType::Delete,
+            DatabaseTapeRowChangeType::Update { .. } => DatabaseChangeType::Update,
+            DatabaseTapeRowChangeType::Insert { .. } => DatabaseChangeType::Insert,
+        }
+    }
+}
+
 /// [DatabaseTapeOperation] extends [DatabaseTapeRowChange] by adding information about transaction boundary
 ///
 /// This helps [crate::database_tape::DatabaseTapeSession] to properly maintain transaction state and COMMIT or ROLLBACK changes in appropriate time
@@ -286,3 +300,15 @@ pub enum ProtocolCommand {
     // Protocol waits for some IO - caller must spin turso-db IO event loop and also drive ProtocolIO
     IO,
 }
+
+pub fn parse_bin_record(bin_record: Vec<u8>) -> Result<Vec<turso_core::Value>> {
+    let record = turso_core::types::ImmutableRecord::from_bin_record(bin_record);
+    let mut cursor = turso_core::types::RecordCursor::new();
+    let columns = cursor.count(&record);
+    let mut values = Vec::with_capacity(columns);
+    for i in 0..columns {
+        let value = cursor.get_value(&record, i)?;
+        values.push(value.to_owned());
+    }
+    Ok(values)
+}
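For intuition about the CDC "updates" record convention that replay_values and the replay loop above both rely on (the first half of the record holds 0/1 flags marking changed columns, the second half the corresponding new values), here is a self-contained sketch with a stand-in Value type; this is an illustration only, not the crate's actual API:

#[derive(Debug, Clone, PartialEq)]
enum Value {
    Integer(i64),
    Text(String),
    Null,
}

/// Extract the new values for changed columns from an `updates` record.
fn changed_values(mut updates: Vec<Value>) -> Vec<Value> {
    // The record always has an even length: N flags followed by N values.
    assert!(updates.len() % 2 == 0);
    let columns_cnt = updates.len() / 2;
    let mut values = Vec::new();
    for i in 0..columns_cnt {
        // A flag of Integer(1) means column i was changed by the UPDATE.
        if matches!(updates[i], Value::Integer(1)) {
            values.push(std::mem::replace(&mut updates[i + columns_cnt], Value::Null));
        }
    }
    values
}

fn main() {
    // Two-column table where only the second column changed:
    let updates = vec![
        Value::Integer(0), Value::Integer(1),          // first half: change flags
        Value::Null, Value::Text("new".to_string()),   // second half: new values
    ];
    assert_eq!(changed_values(updates), vec![Value::Text("new".to_string())]);
}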