From 009aa479bfd2481808b0070358eeac1f14ebecd2 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Wed, 27 Aug 2025 15:29:30 +0400 Subject: [PATCH] improve sync engine --- Cargo.lock | 35 + sync/engine/Cargo.toml | 2 + sync/engine/src/database_replay_generator.rs | 363 +++-- sync/engine/src/database_sync_engine.rs | 961 ++++++----- sync/engine/src/database_sync_operations.rs | 1442 ++++++++++------- sync/engine/src/database_tape.rs | 191 ++- sync/engine/src/errors.rs | 2 + sync/engine/src/io_operations.rs | 8 +- sync/engine/src/protocol_io.rs | 9 +- sync/engine/src/server_proto.rs | 183 ++- sync/engine/src/types.rs | 84 +- sync/engine/src/wal_session.rs | 6 +- sync/javascript/index.d.ts | 57 +- sync/javascript/index.js | 171 +- sync/javascript/package.json | 2 +- sync/javascript/src/generator.rs | 28 +- sync/javascript/src/js_protocol_io.rs | 6 + sync/javascript/src/lib.rs | 286 +++- sync/javascript/sync_engine.ts | 31 +- sync/javascript/turso-sync-js.wasi-browser.js | 2 + sync/javascript/turso-sync-js.wasi.cjs | 2 + 21 files changed, 2482 insertions(+), 1389 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c141cca43..5330e3d75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2838,6 +2838,29 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "py-turso" version = "0.1.4" @@ -3159,6 +3182,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "roaring" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f08d6a905edb32d74a5d5737a0c9d7e950c312f3c46cb0ca0a2ca09ea11878a0" +dependencies = [ + "bytemuck", + "byteorder", +] + [[package]] name = "rstest" version = "0.18.2" @@ -4199,8 +4232,10 @@ dependencies = [ "futures", "genawaiter", "http", + "prost", "rand 0.9.2", "rand_chacha 0.9.0", + "roaring", "serde", "serde_json", "tempfile", diff --git a/sync/engine/Cargo.toml b/sync/engine/Cargo.toml index c2a13c4ee..229c60714 100644 --- a/sync/engine/Cargo.toml +++ b/sync/engine/Cargo.toml @@ -17,6 +17,8 @@ genawaiter = { version = "0.99.1", default-features = false } http = "1.3.1" uuid = "1.17.0" base64 = "0.22.1" +prost = "0.14.1" +roaring = "0.11.2" [dev-dependencies] ctor = "0.4.2" diff --git a/sync/engine/src/database_replay_generator.rs b/sync/engine/src/database_replay_generator.rs index cceb6f98f..553a0243a 100644 --- a/sync/engine/src/database_replay_generator.rs +++ b/sync/engine/src/database_replay_generator.rs @@ -1,29 +1,110 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use crate::{ database_tape::{run_stmt_once, DatabaseReplaySessionOpts}, errors::Error, - types::{Coro, DatabaseChangeType, DatabaseTapeRowChange, DatabaseTapeRowChangeType}, + types::{ + Coro, DatabaseChangeType, DatabaseRowMutation, DatabaseTapeRowChange, + DatabaseTapeRowChangeType, + }, Result, }; -pub struct DatabaseReplayGenerator { +pub struct DatabaseReplayGenerator { pub conn: Arc, - pub opts: DatabaseReplaySessionOpts, + pub opts: DatabaseReplaySessionOpts, } pub struct ReplayInfo { 
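    // As used by the generator below, ReplayInfo bundles everything needed to
    // replay a single captured change: the change kind, a prepared SQL template
    // with `?` placeholders, the primary-key column indices for binding WHERE
    // predicates (None when the implicit rowid is used instead), the table's
    // column names, and a flag marking schema-table changes replayed as DDL.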
pub change_type: DatabaseChangeType, pub query: String, pub pk_column_indices: Option>, + pub column_names: Vec, pub is_ddl_replay: bool, } const SQLITE_SCHEMA_TABLE: &str = "sqlite_schema"; -impl DatabaseReplayGenerator { - pub fn new(conn: Arc, opts: DatabaseReplaySessionOpts) -> Self { +impl DatabaseReplayGenerator { + pub fn new(conn: Arc, opts: DatabaseReplaySessionOpts) -> Self { Self { conn, opts } } + pub fn create_mutation( + &self, + info: &ReplayInfo, + change: &DatabaseTapeRowChange, + ) -> Result { + match &change.change { + DatabaseTapeRowChangeType::Delete { before } => Ok(DatabaseRowMutation { + change_time: change.change_time, + table_name: change.table_name.to_string(), + id: change.id, + change_type: info.change_type, + before: Some(self.create_row_full(info, before)), + after: None, + updates: None, + }), + DatabaseTapeRowChangeType::Insert { after } => Ok(DatabaseRowMutation { + change_time: change.change_time, + table_name: change.table_name.to_string(), + id: change.id, + change_type: info.change_type, + before: None, + after: Some(self.create_row_full(info, after)), + updates: None, + }), + DatabaseTapeRowChangeType::Update { + before, + after, + updates, + } => Ok(DatabaseRowMutation { + change_time: change.change_time, + table_name: change.table_name.to_string(), + id: change.id, + change_type: info.change_type, + before: Some(self.create_row_full(info, before)), + after: Some(self.create_row_full(info, after)), + updates: updates + .as_ref() + .map(|updates| self.create_row_update(info, &updates)), + }), + } + } + fn create_row_full( + &self, + info: &ReplayInfo, + values: &Vec, + ) -> HashMap { + let mut row = HashMap::with_capacity(info.column_names.len()); + for (i, value) in values.iter().enumerate() { + row.insert(info.column_names[i].clone(), value.clone()); + } + row + } + fn create_row_update( + &self, + info: &ReplayInfo, + updates: &Vec, + ) -> HashMap { + let mut row = HashMap::with_capacity(info.column_names.len()); + assert!(updates.len() % 2 == 0); + let columns_cnt = updates.len() / 2; + for (i, value) in updates.iter().take(columns_cnt).enumerate() { + let updated = match value { + turso_core::Value::Integer(x @ (1 | 0)) => *x > 0, + _ => { + panic!("unexpected 'changes' binary record first-half component: {value:?}") + } + }; + if !updated { + continue; + } + row.insert( + info.column_names[i].clone(), + updates[columns_cnt + i].clone(), + ); + } + row + } pub fn replay_values( &self, info: &ReplayInfo, @@ -89,9 +170,9 @@ impl DatabaseReplayGenerator { } pub async fn replay_info( &self, - coro: &Coro, + coro: &Coro, change: &DatabaseTapeRowChange, - ) -> Result> { + ) -> Result { tracing::trace!("replay: change={:?}", change); let table_name = &change.table_name; @@ -117,9 +198,10 @@ impl DatabaseReplayGenerator { change_type: DatabaseChangeType::Delete, query, pk_column_indices: None, + column_names: Vec::new(), is_ddl_replay: true, }; - Ok(vec![delete]) + Ok(delete) } DatabaseTapeRowChangeType::Insert { after } => { assert!(after.len() == 5); @@ -133,9 +215,10 @@ impl DatabaseReplayGenerator { change_type: DatabaseChangeType::Insert, query: sql.as_str().to_string(), pk_column_indices: None, + column_names: Vec::new(), is_ddl_replay: true, }; - Ok(vec![insert]) + Ok(insert) } DatabaseTapeRowChangeType::Update { updates, .. 
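    // sqlite_schema rows reach this arm when a schema object is rewritten in
    // place; the captured update record carries the new object definition,
    // which is replayed below as a DDL statement (is_ddl_replay: true).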
} => { let Some(updates) = updates else { @@ -155,16 +238,17 @@ impl DatabaseReplayGenerator { change_type: DatabaseChangeType::Update, query: ddl_stmt.as_str().to_string(), pk_column_indices: None, + column_names: Vec::new(), is_ddl_replay: true, }; - Ok(vec![update]) + Ok(update) } } } else { match &change.change { DatabaseTapeRowChangeType::Delete { .. } => { let delete = self.delete_query(coro, table_name).await?; - Ok(vec![delete]) + Ok(delete) } DatabaseTapeRowChangeType::Update { updates, after, .. } => { if let Some(updates) = updates { @@ -178,32 +262,159 @@ impl DatabaseReplayGenerator { }); } let update = self.update_query(coro, table_name, &columns).await?; - Ok(vec![update]) + Ok(update) } else { - let delete = self.delete_query(coro, table_name).await?; - let insert = self.insert_query(coro, table_name, after.len()).await?; - Ok(vec![delete, insert]) + let columns = [true].repeat(after.len()); + let update = self.update_query(coro, table_name, &columns).await?; + Ok(update) } } DatabaseTapeRowChangeType::Insert { after } => { let insert = self.insert_query(coro, table_name, after.len()).await?; - Ok(vec![insert]) + Ok(insert) } } } } pub(crate) async fn update_query( &self, - coro: &Coro, + coro: &Coro, table_name: &str, columns: &[bool], ) -> Result { + let (column_names, pk_column_indices) = self.table_columns_info(coro, table_name).await?; + let mut pk_predicates = Vec::with_capacity(1); + let mut column_updates = Vec::with_capacity(1); + for &idx in &pk_column_indices { + pk_predicates.push(format!("{} = ?", column_names[idx])); + } + for (idx, name) in column_names.iter().enumerate() { + if columns[idx as usize] { + column_updates.push(format!("{name} = ?")); + } + } + let (query, pk_column_indices) = + if self.opts.use_implicit_rowid || pk_column_indices.is_empty() { + ( + format!( + "UPDATE {table_name} SET {} WHERE rowid = ?", + column_updates.join(", ") + ), + None, + ) + } else { + ( + format!( + "UPDATE {table_name} SET {} WHERE {}", + column_updates.join(", "), + pk_predicates.join(" AND ") + ), + Some(pk_column_indices), + ) + }; + Ok(ReplayInfo { + change_type: DatabaseChangeType::Update, + query, + column_names, + pk_column_indices, + is_ddl_replay: false, + }) + } + pub(crate) async fn insert_query( + &self, + coro: &Coro, + table_name: &str, + columns: usize, + ) -> Result { + let (mut column_names, pk_column_indices) = self.table_columns_info(coro, table_name).await?; + let conflict_clause = if !pk_column_indices.is_empty() { + let mut pk_column_names = Vec::new(); + for &idx in &pk_column_indices { + pk_column_names.push(column_names[idx].clone()); + } + let mut update_clauses = Vec::new(); + for name in &column_names { + update_clauses.push(format!("{name} = excluded.{name}")); + } + format!( + "ON CONFLICT({}) DO UPDATE SET {}", + pk_column_names.join(","), + update_clauses.join(",") + ) + } else { + String::new() + }; + if !self.opts.use_implicit_rowid { + let placeholders = ["?"].repeat(columns).join(","); + let query = + format!("INSERT INTO {table_name} VALUES ({placeholders}){conflict_clause}"); + return Ok(ReplayInfo { + change_type: DatabaseChangeType::Insert, + query, + pk_column_indices: None, + column_names, + is_ddl_replay: false, + }); + }; + let original_column_names = column_names.clone(); + column_names.push("rowid".to_string()); + + let placeholders = ["?"].repeat(columns + 1).join(","); + let column_names = column_names.join(", "); + let query = format!("INSERT INTO {table_name}({column_names}) VALUES ({placeholders})"); + Ok(ReplayInfo { 
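    // When use_implicit_rowid is set, the INSERT lists every table column plus
    // "rowid" explicitly (hence columns + 1 placeholders), so a replayed row
    // keeps its original rowid on the target database.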
+ change_type: DatabaseChangeType::Insert, + query, + column_names: original_column_names, + pk_column_indices: None, + is_ddl_replay: false, + }) + } + pub(crate) async fn delete_query( + &self, + coro: &Coro, + table_name: &str, + ) -> Result { + let (column_names, pk_column_indices) = self.table_columns_info(coro, table_name).await?; + let mut pk_predicates = Vec::with_capacity(1); + for &idx in &pk_column_indices { + pk_predicates.push(format!("{} = ?", column_names[idx])); + } + let use_implicit_rowid = self.opts.use_implicit_rowid; + if pk_column_indices.is_empty() || use_implicit_rowid { + let query = format!("DELETE FROM {table_name} WHERE rowid = ?"); + tracing::trace!("delete_query: table_name={table_name}, query={query}, use_implicit_rowid={use_implicit_rowid}"); + return Ok(ReplayInfo { + change_type: DatabaseChangeType::Delete, + query, + column_names, + pk_column_indices: None, + is_ddl_replay: false, + }); + } + let pk_predicates = pk_predicates.join(" AND "); + let query = format!("DELETE FROM {table_name} WHERE {pk_predicates}"); + + tracing::trace!("delete_query: table_name={table_name}, query={query}, use_implicit_rowid={use_implicit_rowid}"); + Ok(ReplayInfo { + change_type: DatabaseChangeType::Delete, + query, + column_names, + pk_column_indices: Some(pk_column_indices), + is_ddl_replay: false, + }) + } + + async fn table_columns_info( + &self, + coro: &Coro, + table_name: &str, + ) -> Result<(Vec, Vec)> { let mut table_info_stmt = self.conn.prepare(format!( "SELECT cid, name, pk FROM pragma_table_info('{table_name}')" ))?; - let mut pk_predicates = Vec::with_capacity(1); let mut pk_column_indices = Vec::with_capacity(1); - let mut column_updates = Vec::with_capacity(1); + let mut column_names = Vec::new(); while let Some(column) = run_stmt_once(coro, &mut table_info_stmt).await? { let turso_core::Value::Integer(column_id) = column.get_value(0) else { return Err(Error::DatabaseTapeError( @@ -221,118 +432,10 @@ impl DatabaseReplayGenerator { )); }; if *pk == 1 { - pk_predicates.push(format!("{name} = ?")); pk_column_indices.push(*column_id as usize); } - if columns[*column_id as usize] { - column_updates.push(format!("{name} = ?")); - } + column_names.push(name.as_str().to_string()); } - - let (query, pk_column_indices) = if self.opts.use_implicit_rowid { - ( - format!( - "UPDATE {table_name} SET {} WHERE rowid = ?", - column_updates.join(", ") - ), - None, - ) - } else { - ( - format!( - "UPDATE {table_name} SET {} WHERE {}", - column_updates.join(", "), - pk_predicates.join(" AND ") - ), - Some(pk_column_indices), - ) - }; - Ok(ReplayInfo { - change_type: DatabaseChangeType::Update, - query, - pk_column_indices, - is_ddl_replay: false, - }) - } - pub(crate) async fn insert_query( - &self, - coro: &Coro, - table_name: &str, - columns: usize, - ) -> Result { - if !self.opts.use_implicit_rowid { - let placeholders = ["?"].repeat(columns).join(","); - let query = format!("INSERT INTO {table_name} VALUES ({placeholders})"); - return Ok(ReplayInfo { - change_type: DatabaseChangeType::Insert, - query, - pk_column_indices: None, - is_ddl_replay: false, - }); - }; - let mut table_info_stmt = self.conn.prepare(format!( - "SELECT name FROM pragma_table_info('{table_name}')" - ))?; - let mut column_names = Vec::with_capacity(columns + 1); - while let Some(column) = run_stmt_once(coro, &mut table_info_stmt).await? 
{ - let turso_core::Value::Text(text) = column.get_value(0) else { - return Err(Error::DatabaseTapeError( - "unexpected column type for pragma_table_info query".to_string(), - )); - }; - column_names.push(text.to_string()); - } - column_names.push("rowid".to_string()); - - let placeholders = ["?"].repeat(columns + 1).join(","); - let column_names = column_names.join(", "); - let query = format!("INSERT INTO {table_name}({column_names}) VALUES ({placeholders})"); - Ok(ReplayInfo { - change_type: DatabaseChangeType::Insert, - query, - pk_column_indices: None, - is_ddl_replay: false, - }) - } - pub(crate) async fn delete_query(&self, coro: &Coro, table_name: &str) -> Result { - let (query, pk_column_indices) = if self.opts.use_implicit_rowid { - (format!("DELETE FROM {table_name} WHERE rowid = ?"), None) - } else { - let mut pk_info_stmt = self.conn.prepare(format!( - "SELECT cid, name FROM pragma_table_info('{table_name}') WHERE pk = 1" - ))?; - let mut pk_predicates = Vec::with_capacity(1); - let mut pk_column_indices = Vec::with_capacity(1); - while let Some(column) = run_stmt_once(coro, &mut pk_info_stmt).await? { - let turso_core::Value::Integer(column_id) = column.get_value(0) else { - return Err(Error::DatabaseTapeError( - "unexpected column type for pragma_table_info query".to_string(), - )); - }; - let turso_core::Value::Text(name) = column.get_value(1) else { - return Err(Error::DatabaseTapeError( - "unexpected column type for pragma_table_info query".to_string(), - )); - }; - pk_predicates.push(format!("{name} = ?")); - pk_column_indices.push(*column_id as usize); - } - - if pk_column_indices.is_empty() { - (format!("DELETE FROM {table_name} WHERE rowid = ?"), None) - } else { - let pk_predicates = pk_predicates.join(" AND "); - let query = format!("DELETE FROM {table_name} WHERE {pk_predicates}"); - (query, Some(pk_column_indices)) - } - }; - let use_implicit_rowid = self.opts.use_implicit_rowid; - tracing::trace!("delete_query: table_name={table_name}, query={query}, use_implicit_rowid={use_implicit_rowid}"); - Ok(ReplayInfo { - change_type: DatabaseChangeType::Delete, - query, - pk_column_indices, - is_ddl_replay: false, - }) + Ok((column_names, pk_column_indices)) } } diff --git a/sync/engine/src/database_sync_engine.rs b/sync/engine/src/database_sync_engine.rs index faf43a776..b15fa876b 100644 --- a/sync/engine/src/database_sync_engine.rs +++ b/sync/engine/src/database_sync_engine.rs @@ -1,253 +1,93 @@ -use std::sync::Arc; +use std::{ + cell::RefCell, + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use turso_core::OpenFlags; +use uuid::Uuid; use crate::{ + database_replay_generator::DatabaseReplayGenerator, database_sync_operations::{ - checkpoint_wal_file, connect, connect_untracked, db_bootstrap, push_logical_changes, - reset_wal_file, transfer_logical_changes, transfer_physical_changes, wait_full_body, - wal_pull, wal_push, WalPullResult, + bootstrap_db_file, connect_untracked, count_local_changes, fetch_last_change_id, has_table, + push_logical_changes, read_wal_salt, reset_wal_file, update_last_change_id, wait_full_body, + wal_apply_from_file, wal_pull_to_file, PAGE_SIZE, WAL_FRAME_HEADER, WAL_FRAME_SIZE, + }, + database_tape::{ + DatabaseChangesIteratorMode, DatabaseChangesIteratorOpts, DatabaseReplaySession, + DatabaseReplaySessionOpts, DatabaseTape, DatabaseTapeOpts, DatabaseWalSession, + CDC_PRAGMA_NAME, }, - database_tape::DatabaseTape, errors::Error, io_operations::IoOperations, protocol_io::ProtocolIO, - types::{Coro, DatabaseMetadata}, + types::{ + Coro, 
DatabaseMetadata, DatabasePullRevision, DatabaseRowMutation, DatabaseRowStatement, + DatabaseSyncEngineProtocolVersion, DatabaseTapeOperation, DbChangesStatus, SyncEngineStats, + }, wal_session::WalSession, Result, }; -#[derive(Debug, Clone)] -pub struct DatabaseSyncEngineOpts { +#[derive(Clone)] +pub struct DatabaseSyncEngineOpts { pub client_name: String, + pub tables_ignore: Vec, + pub transform: Option< + Arc Result> + 'static>, + >, pub wal_pull_batch_size: u64, + pub protocol_version_hint: DatabaseSyncEngineProtocolVersion, } -pub struct DatabaseSyncEngine { +impl std::fmt::Debug for DatabaseSyncEngineOpts { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DatabaseSyncEngineOpts") + .field("client_name", &self.client_name) + .field("tables_ignore", &self.tables_ignore) + .field("transform.is_some()", &self.transform.is_some()) + .field("wal_pull_batch_size", &self.wal_pull_batch_size) + .finish() + } +} + +pub struct DatabaseSyncEngine { io: Arc, protocol: Arc
, - draft_tape: DatabaseTape, - draft_path: String, - synced_path: String, + db_file: Arc, + main_tape: DatabaseTape, + main_db_wal_path: String, + revert_db_wal_path: String, + main_db_path: String, meta_path: String, - opts: DatabaseSyncEngineOpts, - meta: Option, - // we remember information if Synced DB is dirty - which will make Database to reset it in case of any sync attempt - // this bit is set to false when we properly reset Synced DB - // this bit is set to true when we transfer changes from Draft to Synced or on initialization - synced_is_dirty: bool, + opts: DatabaseSyncEngineOpts, + meta: RefCell, } -async fn update_meta( - coro: &Coro, - io: &IO, - meta_path: &str, - orig: &mut Option, - update: impl FnOnce(&mut DatabaseMetadata), -) -> Result<()> { - let mut meta = orig.as_ref().unwrap().clone(); - update(&mut meta); - tracing::info!("update_meta: {meta:?}"); - let completion = io.full_write(meta_path, meta.dump()?)?; - // todo: what happen if we will actually update the metadata on disk but fail and so in memory state will not be updated - wait_full_body(coro, &completion).await?; - *orig = Some(meta); - Ok(()) +fn db_size_from_page(page: &[u8]) -> u32 { + u32::from_be_bytes(page[28..28 + 4].try_into().unwrap()) } -async fn set_meta( - coro: &Coro, - io: &IO, - meta_path: &str, - orig: &mut Option, - meta: DatabaseMetadata, -) -> Result<()> { - tracing::info!("set_meta: {meta:?}"); - let completion = io.full_write(meta_path, meta.dump()?)?; - // todo: what happen if we will actually update the metadata on disk but fail and so in memory state will not be updated - wait_full_body(coro, &completion).await?; - *orig = Some(meta); - Ok(()) -} - -impl DatabaseSyncEngine { +impl DatabaseSyncEngine { /// Creates new instance of SyncEngine and initialize it immediately if no consistent local data exists pub async fn new( - coro: &Coro, + coro: &Coro, io: Arc, - protocol: Arc, - path: &str, - opts: DatabaseSyncEngineOpts, + protocol: Arc
, + main_db_path: &str, + opts: DatabaseSyncEngineOpts, ) -> Result { - let draft_path = format!("{path}-draft"); - let draft_tape = io.open_tape(&draft_path, true)?; - let mut db = Self { - io, - protocol, - draft_tape, - draft_path, - synced_path: format!("{path}-synced"), - meta_path: format!("{path}-info"), - opts, - meta: None, - synced_is_dirty: true, - }; - db.init(coro).await?; - Ok(db) - } + let main_db_wal_path = format!("{main_db_path}-wal"); + let revert_db_wal_path = format!("{main_db_path}-wal-revert"); + let meta_path = format!("{main_db_path}-info"); - /// Create database connection and appropriately configure it before use - pub async fn connect(&self, coro: &Coro) -> Result> { - connect(coro, &self.draft_tape).await - } + let db_file = io.open_file(main_db_path, turso_core::OpenFlags::Create, false)?; + let db_file = Arc::new(turso_core::storage::database::DatabaseFile::new(db_file)); - /// Sync all new changes from remote DB and apply them locally - /// This method will **not** send local changed to the remote - /// This method will block writes for the period of pull - pub async fn pull(&mut self, coro: &Coro) -> Result<()> { - tracing::info!( - "pull: draft={}, synced={}", - self.draft_path, - self.synced_path - ); + tracing::info!("init(path={}): opts={:?}", main_db_path, opts); - // reset Synced DB if it wasn't properly cleaned-up on previous "sync-method" attempt - self.reset_synced_if_dirty(coro).await?; - - loop { - // update Synced DB with fresh changes from remote - let pull_result = self.pull_synced_from_remote(coro).await?; - - { - // we will "replay" Synced WAL to the Draft WAL later without pushing it to the remote - // so, we pass 'capture: true' as we need to preserve all changes for future push of WAL - let synced = self.io.open_tape(&self.synced_path, true)?; - - // we will start wal write session for Draft DB in order to hold write lock during transfer of changes - let mut draft_session = WalSession::new(connect(coro, &self.draft_tape).await?); - draft_session.begin()?; - - // mark Synced as dirty as we will start transfer of logical changes there and if we will fail in the middle - we will need to cleanup Synced db - self.synced_is_dirty = true; - - // transfer logical changes to the Synced DB in order to later execute physical "rebase" operation - let client_id = &self.meta().client_unique_id; - transfer_logical_changes(coro, &self.draft_tape, &synced, client_id, true).await?; - - // now we are ready to do the rebase: let's transfer physical changes from Synced to Draft - let synced_wal_watermark = self.meta().synced_wal_match_watermark; - let synced_sync_watermark = self.meta().synced_frame_no.expect( - "synced_frame_no must be set as we call pull_synced_from_remote before that", - ); - let draft_wal_watermark = self.meta().draft_wal_match_watermark; - let draft_sync_watermark = transfer_physical_changes( - coro, - &synced, - draft_session, - synced_wal_watermark, - synced_sync_watermark, - draft_wal_watermark, - ) - .await?; - update_meta( - coro, - self.protocol.as_ref(), - &self.meta_path, - &mut self.meta, - |m| { - m.draft_wal_match_watermark = draft_sync_watermark; - m.synced_wal_match_watermark = synced_sync_watermark; - }, - ) - .await?; - } - - // Synced DB is 100% dirty now - let's reset it - assert!(self.synced_is_dirty); - self.reset_synced_if_dirty(coro).await?; - - let WalPullResult::NeedCheckpoint = pull_result else { - break; - }; - tracing::info!( - "ready to checkpoint synced db file at {:?}, generation={}", - self.synced_path, - 
self.meta().synced_generation - ); - { - let synced = self.io.open_tape(&self.synced_path, false)?; - checkpoint_wal_file(coro, &connect_untracked(&synced)?).await?; - update_meta( - coro, - self.protocol.as_ref(), - &self.meta_path, - &mut self.meta, - |m| { - m.synced_generation += 1; - m.synced_frame_no = Some(0); - m.synced_wal_match_watermark = 0; - }, - ) - .await?; - } - } - - Ok(()) - } - - /// Sync local changes to remote DB - /// This method will **not** pull remote changes to the local DB - /// This method will **not** block writes for the period of sync - pub async fn push(&mut self, coro: &Coro) -> Result<()> { - tracing::info!( - "push: draft={}, synced={}", - self.draft_path, - self.synced_path - ); - - // reset Synced DB if it wasn't properly cleaned-up on previous "sync-method" attempt - self.reset_synced_if_dirty(coro).await?; - - // update Synced DB with fresh changes from remote in order to avoid WAL frame conflicts - self.pull_synced_from_remote(coro).await?; - - // we will push Synced WAL to the remote - // so, we pass 'capture: false' as we don't need to preserve changes made to Synced WAL in turso_cdc - let synced = self.io.open_tape(&self.synced_path, false)?; - - self.synced_is_dirty = true; - let client_id = &self.meta().client_unique_id; - push_logical_changes( - coro, - self.protocol.as_ref(), - &self.draft_tape, - &synced, - client_id, - ) - .await?; - - self.reset_synced_if_dirty(coro).await?; - - Ok(()) - } - - /// Sync local changes to remote DB and bring new changes from remote to local - /// This method will block writes for the period of sync - pub async fn sync(&mut self, coro: &Coro) -> Result<()> { - // todo(sivukhin): this is bit suboptimal as both 'push' and 'pull' will call pull_synced_from_remote - // but for now - keep it simple - self.push(coro).await?; - self.pull(coro).await?; - Ok(()) - } - - async fn init(&mut self, coro: &Coro) -> Result<()> { - tracing::info!( - "initialize sync engine: draft={}, synced={}, opts={:?}", - self.draft_path, - self.synced_path, - self.opts, - ); - - let completion = self.protocol.full_read(&self.meta_path)?; + let completion = protocol.full_read(&meta_path)?; let data = wait_full_body(coro, &completion).await?; let meta = if data.is_empty() { None @@ -255,200 +95,551 @@ impl DatabaseSyncEngine { Some(DatabaseMetadata::load(&data)?) 
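    // An empty or missing metadata file marks the very first run for this
    // database path: the None branch below bootstraps the DB file from the
    // remote and persists freshly generated metadata before continuing.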
}; - match meta { - Some(meta) => { - self.meta = Some(meta); - } + let meta = match meta { + Some(meta) => meta, None => { - let meta = self.bootstrap_db_files(coro).await?; - tracing::info!("write meta after successful bootstrap: meta={meta:?}"); - set_meta( + let client_unique_id = format!("{}-{}", opts.client_name, uuid::Uuid::new_v4()); + let revision = bootstrap_db_file( coro, - self.protocol.as_ref(), - &self.meta_path, - &mut self.meta, - meta, + protocol.as_ref(), + &io, + &main_db_path, + opts.protocol_version_hint, ) .await?; + let meta = DatabaseMetadata { + client_unique_id, + synced_revision: Some(revision), + revert_since_wal_salt: None, + revert_since_wal_watermark: 0, + last_pushed_change_id_hint: 0, + last_pushed_pull_gen_hint: 0, + }; + tracing::info!("write meta after successful bootstrap: meta={meta:?}"); + let completion = protocol.full_write(&meta_path, meta.dump()?)?; + // todo: what happen if we will actually update the metadata on disk but fail and so in memory state will not be updated + wait_full_body(coro, &completion).await?; + meta } }; - let draft_exists = self.io.try_open(&self.draft_path)?.is_some(); - let synced_exists = self.io.try_open(&self.synced_path)?.is_some(); - if !draft_exists || !synced_exists { - let error = "Draft or Synced files doesn't exists, but metadata is".to_string(); + let main_exists = io.try_open(&main_db_path)?.is_some(); + if !main_exists { + let error = "main DB file doesn't exists, but metadata is".to_string(); return Err(Error::DatabaseSyncEngineError(error)); } - if self.meta().synced_frame_no.is_none() { - // sync WAL from the remote in case of bootstrap - all subsequent initializations will be fast - self.pull(coro).await?; - } - Ok(()) - } - - async fn pull_synced_from_remote(&mut self, coro: &Coro) -> Result { - tracing::info!( - "pull_synced_from_remote: draft={:?}, synced={:?}", - self.draft_path, - self.synced_path, - ); - let synced = self.io.open_tape(&self.synced_path, false)?; - let synced_conn = connect(coro, &synced).await?; - let mut wal = WalSession::new(synced_conn); - - let generation = self.meta().synced_generation; - let mut start_frame = self.meta().synced_frame_no.unwrap_or(0) + 1; - loop { - let end_frame = start_frame + self.opts.wal_pull_batch_size; - let update = async |coro, frame_no| { - update_meta( - coro, - self.protocol.as_ref(), - &self.meta_path, - &mut self.meta, - |m| m.synced_frame_no = Some(frame_no), - ) - .await - }; - match wal_pull( - coro, - self.protocol.as_ref(), - &mut wal, - generation, - start_frame, - end_frame, - update, - ) - .await? 
- { - WalPullResult::Done => return Ok(WalPullResult::Done), - WalPullResult::NeedCheckpoint => return Ok(WalPullResult::NeedCheckpoint), - WalPullResult::PullMore => { - start_frame = end_frame; - continue; - } - } - } - } - - #[allow(dead_code)] - async fn push_synced_to_remote(&mut self, coro: &Coro) -> Result<()> { - tracing::info!( - "push_synced_to_remote: draft={}, synced={}, id={}", - self.draft_path, - self.synced_path, - self.meta().client_unique_id - ); - let synced = self.io.open_tape(&self.synced_path, false)?; - let synced_conn = connect(coro, &synced).await?; - - let mut wal = WalSession::new(synced_conn); - wal.begin()?; - - // todo(sivukhin): push frames in multiple batches - let generation = self.meta().synced_generation; - let start_frame = self.meta().synced_frame_no.unwrap_or(0) + 1; - let end_frame = wal.conn().wal_state()?.max_frame + 1; - match wal_push( - coro, - self.protocol.as_ref(), - &mut wal, - None, - generation, - start_frame, - end_frame, + let main_db = turso_core::Database::open_with_flags( + io.clone(), + main_db_path, + db_file.clone(), + OpenFlags::Create, + false, + true, + false, ) - .await + .unwrap(); + let tape_opts = DatabaseTapeOpts { + cdc_table: None, + cdc_mode: Some("full".to_string()), + }; + let main_tape = DatabaseTape::new_with_opts(main_db, tape_opts); + tracing::info!("initialize database tape connection: path={}", main_db_path); + let mut db = Self { + io, + protocol, + db_file, + main_db_wal_path, + main_tape, + revert_db_wal_path, + main_db_path: main_db_path.to_string(), + meta_path: format!("{main_db_path}-info"), + opts, + meta: RefCell::new(meta.clone()), + }; + + let synced_revision = meta.synced_revision.as_ref().unwrap(); + if let DatabasePullRevision::Legacy { + synced_frame_no: None, + .. + } = synced_revision { - Ok(_) => { - update_meta( - coro, - self.protocol.as_ref(), - &self.meta_path, - &mut self.meta, - |m| m.synced_frame_no = Some(end_frame - 1), - ) - .await?; - self.synced_is_dirty = false; - Ok(()) - } - Err(err) => { - tracing::info!("push_synced_to_remote: failed: err={err}"); - Err(err) + // sync WAL from the remote in case of bootstrap - all subsequent initializations will be fast + if let Some(changes) = db.wait_changes_from_remote(coro).await? { + db.apply_changes_from_remote(coro, changes).await?; } } + Ok(db) } - async fn bootstrap_db_files(&mut self, coro: &Coro) -> Result { - assert!( - self.meta.is_none(), - "bootstrap_db_files must be called only when meta is not set" - ); + fn open_revert_db_conn(&mut self) -> Result> { + let db = turso_core::Database::open_with_flags_bypass_registry( + self.io.clone(), + &self.main_db_path, + &self.revert_db_wal_path, + self.db_file.clone(), + OpenFlags::Create, + false, + true, + false, + )?; + let conn = db.connect()?; + conn.wal_auto_checkpoint_disable(); + Ok(conn) + } + + async fn checkpoint_passive(&mut self, coro: &Coro) -> Result<(Option>, u64)> { + let watermark = self.meta().revert_since_wal_watermark as u64; tracing::info!( - "bootstrap_db_files: draft={}, synced={}", - self.draft_path, - self.synced_path, + "checkpoint(path={:?}): revert_since_wal_watermark={}", + self.main_db_path, + watermark ); + let main_conn = connect_untracked(&self.main_tape)?; + let main_wal = self.io.try_open(&self.main_db_wal_path)?; + let main_wal_salt = if let Some(main_wal) = main_wal { + read_wal_salt(coro, &main_wal).await? 
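    // The WAL salt (a pair of random u32 values in the WAL header) changes on
    // every WAL reset, so comparing it with the salt stored in the metadata
    // below tells us whether revert_since_wal_watermark still refers to the
    // current WAL incarnation; on mismatch the watermark is reset to 0.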
+ } else { + None + }; - let start_time = std::time::Instant::now(); - // cleanup all files left from previous attempt to bootstrap - // we shouldn't write any WAL files - but let's truncate them too for safety - if let Some(file) = self.io.try_open(&self.draft_path)? { - self.io.truncate(coro, file, 0).await?; - } - if let Some(file) = self.io.try_open(&self.synced_path)? { - self.io.truncate(coro, file, 0).await?; - } - if let Some(file) = self.io.try_open(&format!("{}-wal", self.draft_path))? { - self.io.truncate(coro, file, 0).await?; - } - if let Some(file) = self.io.try_open(&format!("{}-wal", self.synced_path))? { - self.io.truncate(coro, file, 0).await?; - } - - let files = &[ - self.io.create(&self.draft_path)?, - self.io.create(&self.synced_path)?, - ]; - let db_info = db_bootstrap(coro, self.protocol.as_ref(), files).await?; - - let elapsed = std::time::Instant::now().duration_since(start_time); tracing::info!( - "bootstrap_db_files: finished draft={:?}, synced={:?}: elapsed={:?}", - self.draft_path, - self.synced_path, - elapsed + "checkpoint(path={:?}): main_wal_salt={:?}", + self.main_db_path, + main_wal_salt ); - Ok(DatabaseMetadata { - client_unique_id: format!("{}-{}", self.opts.client_name, uuid::Uuid::new_v4()), - synced_generation: db_info.current_generation, - synced_frame_no: None, - draft_wal_match_watermark: 0, - synced_wal_match_watermark: 0, + let revert_since_wal_salt = self.meta().revert_since_wal_salt.clone(); + if revert_since_wal_salt.is_some() && main_wal_salt != revert_since_wal_salt { + self.update_meta(coro, |meta| { + meta.revert_since_wal_watermark = 0; + meta.revert_since_wal_salt = main_wal_salt.clone(); + }) + .await?; + return Ok((main_wal_salt, 0)); + } + // we do this Passive checkpoint in order to transfer all synced frames to the DB file and make history of revert DB valid + // if we will not do that we will be in situation where WAL in the revert DB is not valid relative to the DB file + let result = main_conn.checkpoint(turso_core::CheckpointMode::Passive { + upper_bound_inclusive: Some(watermark), + })?; + tracing::info!( + "checkpoint(path={:?}): checkpointed portion of WAL: {:?}", + self.main_db_path, + result + ); + if result.max_frame < watermark { + return Err(Error::DatabaseSyncEngineError( + format!("unable to checkpoint synced portion of WAL: result={result:?}, watermark={watermark}"), + )); + } + Ok((main_wal_salt, watermark)) + } + + pub async fn stats(&self, coro: &Coro) -> Result { + let main_conn = connect_untracked(&self.main_tape)?; + let change_id = self.meta().last_pushed_change_id_hint; + Ok(SyncEngineStats { + cdc_operations: count_local_changes(coro, &main_conn, change_id).await?, + wal_size: main_conn.wal_state()?.max_frame as i64, }) } - /// Reset WAL of Synced database which potentially can have some local changes - async fn reset_synced_if_dirty(&mut self, coro: &Coro) -> Result<()> { + pub async fn checkpoint(&mut self, coro: &Coro) -> Result<()> { + let (main_wal_salt, watermark) = self.checkpoint_passive(coro).await?; + + let main_conn = connect_untracked(&self.main_tape)?; + let revert_conn = self.open_revert_db_conn()?; + + let mut page = [0u8; PAGE_SIZE]; + let db_size = if revert_conn.try_wal_watermark_read_page(1, &mut page, None)? 
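    // Page 1 holds the database header; db_size_from_page reads the database
    // size in pages from byte offset 28, so a readable page 1 gives the revert
    // DB's committed size while an unreadable one means it is still empty.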
{ + db_size_from_page(&page) + } else { + 0 + }; + tracing::info!( - "reset_synced: synced_path={:?}, synced_is_dirty={}", - self.synced_path, - self.synced_is_dirty + "checkpoint(path={:?}): revert DB initial size: {}", + self.main_db_path, + db_size ); - // if we know that Synced DB is not dirty - let's skip this phase completely - if !self.synced_is_dirty { - return Ok(()); + + let main_wal_state; + { + let mut revert_session = WalSession::new(revert_conn.clone()); + revert_session.begin()?; + + let mut main_session = WalSession::new(main_conn.clone()); + main_session.begin()?; + + main_wal_state = main_conn.wal_state()?; + tracing::info!( + "checkpoint(path={:?}): main DB WAL state: {:?}", + self.main_db_path, + main_wal_state + ); + + let mut revert_session = DatabaseWalSession::new(coro, revert_session).await?; + + let main_changed_pages = main_conn.wal_changed_pages_after(watermark)?; + tracing::info!( + "checkpoint(path={:?}): collected {} changed pages", + self.main_db_path, + main_changed_pages.len() + ); + let revert_changed_pages: HashSet = revert_conn + .wal_changed_pages_after(0)? + .into_iter() + .collect(); + for page_no in main_changed_pages { + if revert_changed_pages.contains(&page_no) { + tracing::info!( + "checkpoint(path={:?}): skip page {} as it present in revert WAL", + self.main_db_path, + page_no + ); + continue; + } + if page_no > db_size { + tracing::info!( + "checkpoint(path={:?}): skip page {} as it ahead of revert-DB size", + self.main_db_path, + page_no + ); + continue; + } + if !main_conn.try_wal_watermark_read_page(page_no, &mut page, Some(watermark))? { + tracing::info!( + "checkpoint(path={:?}): skip page {} as it was allocated in the wAL portion for revert", + self.main_db_path, + page_no + ); + continue; + } + tracing::info!( + "checkpoint(path={:?}): append page {} (current db_size={})", + self.main_db_path, + page_no, + db_size + ); + revert_session.append_page(page_no, &page)?; + } + revert_session.commit(db_size)?; + revert_session.wal_session.end(false)?; } - if let Some(synced_wal) = self.io.try_open(&format!("{}-wal", self.synced_path))? { - reset_wal_file(coro, synced_wal, self.meta().synced_frame_no.unwrap_or(0)).await?; - } - self.synced_is_dirty = false; + self.update_meta(coro, |meta| { + meta.revert_since_wal_salt = main_wal_salt; + meta.revert_since_wal_watermark = main_wal_state.max_frame; + }) + .await?; + + let result = main_conn.checkpoint(turso_core::CheckpointMode::Truncate { + upper_bound_inclusive: Some(main_wal_state.max_frame), + })?; + tracing::info!( + "checkpoint(path={:?}): main DB TRUNCATE checkpoint result: {:?}", + self.main_db_path, + result + ); + Ok(()) } - fn meta(&self) -> &DatabaseMetadata { - self.meta.as_ref().expect("metadata must be set") + pub async fn wait_changes_from_remote( + &self, + coro: &Coro, + ) -> Result> { + let file_path = format!("{}-frames-{}", self.main_db_path, Uuid::new_v4()); + tracing::info!( + "wait_changes(path={}): file_path={}", + self.main_db_path, + file_path + ); + let file = self.io.create(&file_path)?; + + let revision = self.meta().synced_revision.clone().unwrap(); + let next_revision = wal_pull_to_file( + coro, + self.protocol.as_ref(), + file.clone(), + &revision, + self.opts.wal_pull_batch_size, + ) + .await?; + + if file.size()? 
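    // A zero-length frames file means wal_pull_to_file fetched nothing beyond
    // our current revision; drop the temporary file and report "no changes".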
== 0 { + tracing::info!( + "wait_changes(path={}): no changes detected, removing changes file {}", + self.main_db_path, + file_path + ); + self.io.remove_file(&file_path)?; + return Ok(None); + } + + tracing::info!( + "wait_changes_from_remote(path={}): revision: {:?} -> {:?}", + self.main_db_path, + revision, + next_revision + ); + + Ok(Some(DbChangesStatus { + revision: next_revision, + file_path, + })) + } + + /// Sync all new changes from remote DB and apply them locally + /// This method will **not** send local changed to the remote + /// This method will block writes for the period of pull + pub async fn apply_changes_from_remote( + &mut self, + coro: &Coro, + remote_changes: DbChangesStatus, + ) -> Result<()> { + let pull_result = self.apply_changes_internal(coro, &remote_changes).await; + let cleanup_result: Result<()> = self + .io + .remove_file(&remote_changes.file_path) + .inspect_err(|e| tracing::error!("failed to cleanup changes file: {e}")) + .map_err(|e| e.into()); + let Ok(revert_since_wal_watermark) = pull_result else { + return Err(pull_result.err().unwrap()); + }; + + let revert_wal_file = self.io.open_file( + &self.revert_db_wal_path, + turso_core::OpenFlags::Create, + false, + )?; + reset_wal_file(coro, revert_wal_file, 0).await?; + + self.update_meta(coro, |meta| { + meta.revert_since_wal_watermark = revert_since_wal_watermark; + meta.synced_revision = Some(remote_changes.revision); + meta.last_pushed_change_id_hint = 0; + }) + .await?; + + cleanup_result + } + async fn apply_changes_internal( + &mut self, + coro: &Coro, + remote_changes: &DbChangesStatus, + ) -> Result { + tracing::info!( + "apply_changes(path={}, changes={:?})", + self.main_db_path, + remote_changes + ); + + let (_, watermark) = self.checkpoint_passive(coro).await?; + + let changes_file = self.io.open_file( + &remote_changes.file_path, + turso_core::OpenFlags::empty(), + false, + )?; + + let revert_conn = self.open_revert_db_conn()?; + let main_conn = connect_untracked(&self.main_tape)?; + + let mut revert_session = WalSession::new(revert_conn.clone()); + revert_session.begin()?; + + let mut main_session = WalSession::new(main_conn.clone()); + main_session.begin()?; + + let had_cdc_table = has_table(coro, &main_conn, "turso_cdc").await?; + + // read schema version after initiating WAL session (in order to read it with consistent max_frame_no) + let main_conn_schema_version = main_conn.read_schema_version()?; + + let mut main_session = DatabaseWalSession::new(coro, main_session).await?; + + // fetch last_change_id from remote + let (pull_gen, last_change_id) = fetch_last_change_id( + coro, + self.protocol.as_ref(), + &main_conn, + &self.meta().client_unique_id, + ) + .await?; + + // collect local changes before doing anything with the main DB + // it's important to do this after opening WAL session - otherwise we can miss some updates + let iterate_opts = DatabaseChangesIteratorOpts { + first_change_id: last_change_id.map(|x| x + 1), + mode: DatabaseChangesIteratorMode::Apply, + ignore_schema_changes: false, + ..Default::default() + }; + let mut local_changes = Vec::new(); + let mut iterator = self.main_tape.iterate_changes(iterate_opts)?; + while let Some(operation) = iterator.next(coro).await? 
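    // Drain the local CDC tape first: these are row changes committed after the
    // last push, and they are replayed on top of the rebased WAL further below.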
{ + match operation { + DatabaseTapeOperation::RowChange(change) => local_changes.push(change), + DatabaseTapeOperation::Commit => continue, + } + } + tracing::info!( + "apply_changes(path={}): collected {} changes", + self.main_db_path, + local_changes.len() + ); + + // rollback local changes not checkpointed to the revert-db + tracing::info!( + "apply_changes(path={}): rolling back frames after {} watermark, max_frame={}", + self.main_db_path, + watermark, + main_conn.wal_state()?.max_frame + ); + let local_rollback = main_session.rollback_changes_after(watermark)?; + let mut frame = [0u8; WAL_FRAME_SIZE]; + + let remote_rollback = revert_conn.wal_state()?.max_frame; + tracing::info!( + "apply_changes(path={}): rolling back {} frames from revert DB", + self.main_db_path, + remote_rollback + ); + // rollback local changes by using frames from revert-db + // it's important to append pages from revert-db after local revert - because pages from revert-db must overwrite rollback from main DB + for frame_no in 1..=remote_rollback { + let info = revert_session.read_at(frame_no, &mut frame)?; + main_session.append_page(info.page_no, &frame[WAL_FRAME_HEADER..])?; + } + + // after rollback - WAL state is aligned with remote - let's apply changes from it + let db_size = wal_apply_from_file(coro, changes_file, &mut main_session).await?; + tracing::info!( + "apply_changes(path={}): applied changes from remote: db_size={}", + self.main_db_path, + db_size, + ); + + let revert_since_wal_watermark; + if local_changes.is_empty() && local_rollback == 0 && remote_rollback == 0 && !had_cdc_table + { + main_session.commit(db_size)?; + revert_since_wal_watermark = main_session.frames_count()?; + main_session.wal_session.end(false)?; + } else { + main_session.commit(0)?; + let current_schema_version = main_conn.read_schema_version()?; + revert_since_wal_watermark = main_session.frames_count()?; + let final_schema_version = current_schema_version.max(main_conn_schema_version) + 1; + main_conn.write_schema_version(final_schema_version)?; + tracing::info!( + "apply_changes(path={}): updated schema version to {}", + self.main_db_path, + final_schema_version + ); + + update_last_change_id( + coro, + &main_conn, + &self.meta().client_unique_id, + pull_gen + 1, + 0, + ) + .await + .inspect_err(|e| tracing::error!("update_last_change_id failed: {e}"))?; + + if had_cdc_table { + tracing::info!( + "apply_changes(path={}): initiate CDC pragma again in order to recreate CDC table", + self.main_db_path, + ); + let _ = main_conn.pragma_update(CDC_PRAGMA_NAME, "'full'")?; + } + + let mut replay = DatabaseReplaySession { + conn: main_conn.clone(), + cached_delete_stmt: HashMap::new(), + cached_insert_stmt: HashMap::new(), + cached_update_stmt: HashMap::new(), + in_txn: true, + generator: DatabaseReplayGenerator:: { + conn: main_conn.clone(), + opts: DatabaseReplaySessionOpts:: { + use_implicit_rowid: false, + transform: self.opts.transform.clone(), + }, + }, + }; + for change in local_changes { + let operation = DatabaseTapeOperation::RowChange(change); + replay.replay(coro, operation).await?; + } + + main_session.wal_session.end(true)?; + } + + Ok(revert_since_wal_watermark) + } + + /// Sync local changes to remote DB + /// This method will **not** pull remote changes to the local DB + /// This method will **not** block writes for the period of sync + pub async fn push_changes_to_remote(&self, coro: &Coro) -> Result<()> { + tracing::info!("push_changes(path={})", self.main_db_path); + + let (_, change_id) = 
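    // push_logical_changes returns the id of the last CDC change it shipped;
    // it is stored in the metadata below as a hint so the next push can resume
    // from that point instead of re-scanning the whole change log.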
push_logical_changes( + coro, + self.protocol.as_ref(), + &self.main_tape, + &self.meta().client_unique_id, + &self.opts, + ) + .await?; + + self.update_meta(coro, |m| { + m.last_pushed_change_id_hint = change_id; + }) + .await?; + + Ok(()) + } + + /// Create read/write database connection and appropriately configure it before use + pub async fn connect_rw(&self, coro: &Coro) -> Result> { + let conn = self.main_tape.connect(coro).await?; + conn.wal_auto_checkpoint_disable(); + Ok(conn) + } + + /// Sync local changes to remote DB and bring new changes from remote to local + /// This method will block writes for the period of sync + pub async fn sync(&mut self, coro: &Coro) -> Result<()> { + // todo(sivukhin): this is bit suboptimal as both 'push' and 'pull' will call pull_synced_from_remote + // but for now - keep it simple + self.push_changes_to_remote(coro).await?; + if let Some(changes) = self.wait_changes_from_remote(coro).await? { + self.apply_changes_from_remote(coro, changes).await?; + } + Ok(()) + } + + fn meta(&self) -> std::cell::Ref<'_, DatabaseMetadata> { + self.meta.borrow() + } + + async fn update_meta( + &self, + coro: &Coro, + update: impl FnOnce(&mut DatabaseMetadata), + ) -> Result<()> { + let mut meta = self.meta().clone(); + update(&mut meta); + tracing::info!("update_meta: {meta:?}"); + let completion = self.protocol.full_write(&self.meta_path, meta.dump()?)?; + // todo: what happen if we will actually update the metadata on disk but fail and so in memory state will not be updated + wait_full_body(coro, &completion).await?; + self.meta.replace(meta); + Ok(()) } } diff --git a/sync/engine/src/database_sync_operations.rs b/sync/engine/src/database_sync_operations.rs index f43881480..9101d9829 100644 --- a/sync/engine/src/database_sync_operations.rs +++ b/sync/engine/src/database_sync_operations.rs @@ -1,19 +1,29 @@ use std::sync::Arc; -use turso_core::{types::Text, Buffer, Completion, LimboError, Value}; +use bytes::BytesMut; +use prost::Message; +use turso_core::{ + types::{Text, WalFrameInfo}, + Buffer, Completion, LimboError, OpenFlags, Value, +}; use crate::{ database_replay_generator::DatabaseReplayGenerator, + database_sync_engine::DatabaseSyncEngineOpts, database_tape::{ - exec_stmt, run_stmt_expect_one_row, DatabaseChangesIteratorMode, + run_stmt_expect_one_row, run_stmt_ignore_rows, DatabaseChangesIteratorMode, DatabaseChangesIteratorOpts, DatabaseReplaySessionOpts, DatabaseTape, DatabaseWalSession, }, errors::Error, + io_operations::IoOperations, protocol_io::{DataCompletion, DataPollResult, ProtocolIO}, - server_proto::{self, ExecuteStreamReq, Stmt, StreamRequest}, + server_proto::{ + self, ExecuteStreamReq, PageData, PageUpdatesEncodingReq, PullUpdatesReqProtoBody, + PullUpdatesRespProtoBody, Stmt, StmtResult, StreamRequest, + }, types::{ - Coro, DatabaseTapeOperation, DatabaseTapeRowChangeType, DbSyncInfo, DbSyncStatus, - ProtocolCommand, + Coro, DatabasePullRevision, DatabaseSyncEngineProtocolVersion, DatabaseTapeOperation, + DatabaseTapeRowChangeType, DbSyncInfo, DbSyncStatus, ProtocolCommand, }, wal_session::WalSession, Result, @@ -21,31 +31,19 @@ use crate::{ pub const WAL_HEADER: usize = 32; pub const WAL_FRAME_HEADER: usize = 24; -const PAGE_SIZE: usize = 4096; -const WAL_FRAME_SIZE: usize = WAL_FRAME_HEADER + PAGE_SIZE; +pub const PAGE_SIZE: usize = 4096; +pub const WAL_FRAME_SIZE: usize = WAL_FRAME_HEADER + PAGE_SIZE; enum WalHttpPullResult { Frames(C), NeedCheckpoint(DbSyncStatus), } -pub enum WalPullResult { - Done, - PullMore, - NeedCheckpoint, -} 
- pub enum WalPushResult { Ok { baton: Option }, NeedCheckpoint, } -pub async fn connect(coro: &Coro, tape: &DatabaseTape) -> Result> { - let conn = tape.connect(coro).await?; - conn.wal_auto_checkpoint_disable(); - Ok(conn) -} - pub fn connect_untracked(tape: &DatabaseTape) -> Result> { let conn = tape.connect_untracked()?; conn.wal_auto_checkpoint_disable(); @@ -53,10 +51,10 @@ pub fn connect_untracked(tape: &DatabaseTape) -> Result( - coro: &Coro, +pub async fn db_bootstrap( + coro: &Coro, client: &C, - dbs: &[Arc], + db: Arc, ) -> Result { tracing::debug!("db_bootstrap"); let start_time = std::time::Instant::now(); @@ -72,18 +70,15 @@ pub async fn db_bootstrap( #[allow(clippy::arc_with_non_send_sync)] let buffer = Arc::new(Buffer::new_temporary(chunk.len())); buffer.as_mut_slice().copy_from_slice(chunk); - let mut completions = Vec::with_capacity(dbs.len()); - for db in dbs { - let c = Completion::new_write(move |res| { - let Ok(size) = res else { - return; - }; - // todo(sivukhin): we need to error out in case of partial read - assert!(size as usize == content_len); - }); - completions.push(db.pwrite(pos, buffer.clone(), c)?); - } - while !completions.iter().all(|x| x.is_completed()) { + let c = Completion::new_write(move |result| { + // todo(sivukhin): we need to error out in case of partial read + let Ok(size) = result else { + return; + }; + assert!(size as usize == content_len); + }); + let c = db.pwrite(pos, buffer.clone(), c)?; + while !c.is_completed() { coro.yield_(ProtocolCommand::IO).await?; } pos += content_len; @@ -95,14 +90,11 @@ pub async fn db_bootstrap( } // sync files in the end - let mut completions = Vec::with_capacity(dbs.len()); - for db in dbs { - let c = Completion::new_sync(move |_| { - // todo(sivukhin): we need to error out in case of failed sync - }); - completions.push(db.sync(c)?); - } - while !completions.iter().all(|x| x.is_completed()) { + let c = Completion::new_sync(move |_| { + // todo(sivukhin): we need to error out in case of failed sync + }); + let c = db.sync(c)?; + while !c.is_completed() { coro.yield_(ProtocolCommand::IO).await?; } @@ -112,92 +104,275 @@ pub async fn db_bootstrap( Ok(db_info) } -/// Pull updates from remote to the database file -/// -/// Returns [WalPullResult::Done] if pull reached the end of database history -/// Returns [WalPullResult::PullMore] if all frames from [start_frame..end_frame) range were pulled, but remote have more -/// Returns [WalPullResult::NeedCheckpoint] if remote generation increased and local version must be checkpointed -/// -/// Guarantees: -/// 1. Frames are commited to the WAL (i.e. db_size is not zero 0) only at transaction boundaries from remote -/// 2. 
wal_pull is idempotent for fixed generation and can be called multiple times with same frame range -pub async fn wal_pull<'a, C: ProtocolIO, U: AsyncFnMut(&'a Coro, u64) -> Result<()>>( - coro: &'a Coro, +pub async fn wal_apply_from_file( + coro: &Coro, + frames_file: Arc, + session: &mut DatabaseWalSession, +) -> Result { + let size = frames_file.size()?; + assert!(size % WAL_FRAME_SIZE as u64 == 0); + let buffer = Arc::new(Buffer::new_temporary(WAL_FRAME_SIZE)); + tracing::debug!("wal_apply_from_file: size={}", size); + let mut db_size = 0; + for offset in (0..size).step_by(WAL_FRAME_SIZE) { + let c = Completion::new_read(buffer.clone(), move |result| { + let Ok((_, size)) = result else { + return; + }; + // todo(sivukhin): we need to error out in case of partial read + assert!(size as usize == WAL_FRAME_SIZE); + }); + let c = frames_file.pread(offset as usize, c)?; + while !c.is_completed() { + coro.yield_(ProtocolCommand::IO).await?; + } + let info = WalFrameInfo::from_frame_header(buffer.as_slice()); + tracing::debug!("got frame: {:?}", info); + db_size = info.db_size; + session.append_page(info.page_no, &buffer.as_slice()[WAL_FRAME_HEADER..])?; + } + assert!(db_size > 0); + Ok(db_size) +} + +pub async fn wal_pull_to_file( + coro: &Coro, client: &C, - wal_session: &mut WalSession, - generation: u64, + frames_file: Arc, + revision: &DatabasePullRevision, + wal_pull_batch_size: u64, +) -> Result { + match revision { + DatabasePullRevision::Legacy { + generation, + synced_frame_no, + } => { + let start_frame = synced_frame_no.unwrap_or(0) + 1; + wal_pull_to_file_legacy( + coro, + client, + frames_file, + *generation, + start_frame, + wal_pull_batch_size, + ) + .await + } + DatabasePullRevision::V1 { revision } => { + wal_pull_to_file_v1(coro, client, frames_file, revision).await + } + } +} + +/// Pull updates from remote to the separate file +pub async fn wal_pull_to_file_v1( + coro: &Coro, + client: &C, + frames_file: Arc, + revision: &str, +) -> Result { + tracing::info!("wal_pull: revision={revision}"); + let mut bytes = BytesMut::new(); + + let request = PullUpdatesReqProtoBody { + encoding: PageUpdatesEncodingReq::Raw as i32, + server_revision: String::new(), + client_revision: revision.to_string(), + long_poll_timeout_ms: 0, + server_pages: BytesMut::new().into(), + client_pages: BytesMut::new().into(), + }; + let request = request.encode_to_vec(); + let completion = client.http( + "POST", + "/pull-updates", + Some(request), + &[ + ("content-type", "application/protobuf"), + ("accept-encoding", "application/protobuf"), + ], + )?; + let Some(header) = + wait_proto_message::(coro, &completion, &mut bytes).await? 
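    // The /pull-updates response body is a stream of protobuf messages: one
    // PullUpdatesRespProtoBody header first, then zero or more PageData
    // messages that the loop below turns into WAL frames on disk.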
+ else { + return Err(Error::DatabaseSyncEngineError(format!( + "no header returned in the pull-updates protobuf call" + ))); + }; + tracing::info!("wal_pull_to_file: got header={:?}", header); + + let mut offset = 0; + let buffer = Arc::new(Buffer::new_temporary(WAL_FRAME_SIZE)); + + let mut page_data_opt = + wait_proto_message::(coro, &completion, &mut bytes).await?; + while let Some(page_data) = page_data_opt.take() { + let page_id = page_data.page_id; + tracing::info!("received page {}", page_id); + let page = decode_page(&header, page_data)?; + if page.len() != PAGE_SIZE { + return Err(Error::DatabaseSyncEngineError(format!( + "page has unexpected size: {} != {}", + page.len(), + PAGE_SIZE + ))); + } + buffer.as_mut_slice()[WAL_FRAME_HEADER..].copy_from_slice(&page); + page_data_opt = wait_proto_message(coro, &completion, &mut bytes).await?; + let mut frame_info = WalFrameInfo { + db_size: 0, + page_no: page_id as u32 + 1, + }; + if page_data_opt.is_none() { + frame_info.db_size = header.db_size as u32; + } + tracing::info!("page_data_opt: {}", page_data_opt.is_some()); + frame_info.put_to_frame_header(buffer.as_mut_slice()); + + let c = Completion::new_write(move |result| { + // todo(sivukhin): we need to error out in case of partial read + let Ok(size) = result else { + return; + }; + assert!(size as usize == WAL_FRAME_SIZE); + }); + + let c = frames_file.pwrite(offset, buffer.clone(), c)?; + while !c.is_completed() { + coro.yield_(ProtocolCommand::IO).await?; + } + offset += WAL_FRAME_SIZE; + } + + let c = Completion::new_sync(move |_| { + // todo(sivukhin): we need to error out in case of failed sync + }); + let c = frames_file.sync(c)?; + while !c.is_completed() { + coro.yield_(ProtocolCommand::IO).await?; + } + + Ok(DatabasePullRevision::V1 { + revision: header.server_revision, + }) +} + +/// Pull updates from remote to the separate file +pub async fn wal_pull_to_file_legacy( + coro: &Coro, + client: &C, + frames_file: Arc, + mut generation: u64, mut start_frame: u64, - end_frame: u64, - mut update: U, -) -> Result { + wal_pull_batch_size: u64, +) -> Result { tracing::info!( - "wal_pull: generation={}, start_frame={}, end_frame={}", - generation, - start_frame, - end_frame + "wal_pull: generation={generation}, start_frame={start_frame}, wal_pull_batch_size={wal_pull_batch_size}" ); // todo(sivukhin): optimize allocation by using buffer pool in the DatabaseSyncOperations - let mut buffer = Vec::with_capacity(WAL_FRAME_SIZE); - - let result = wal_pull_http(coro, client, generation, start_frame, end_frame).await?; - let data = match result { - WalHttpPullResult::NeedCheckpoint(status) => { - assert!(status.status == "checkpoint_needed"); - tracing::debug!("wal_pull: need checkpoint: status={status:?}"); - if status.generation == generation && status.max_frame_no < start_frame { - tracing::debug!("wal_pull: end of history: status={:?}", status); - update(coro, status.max_frame_no).await?; - return Ok(WalPullResult::Done); + let buffer = Arc::new(Buffer::new_temporary(WAL_FRAME_SIZE)); + let mut buffer_len = 0; + let mut last_offset = 0; + let mut committed_len = 0; + let revision = loop { + let end_frame = start_frame + wal_pull_batch_size; + let result = wal_pull_http(coro, client, generation, start_frame, end_frame).await?; + let data = match result { + WalHttpPullResult::NeedCheckpoint(status) => { + assert!(status.status == "checkpoint_needed"); + tracing::debug!("wal_pull: need checkpoint: status={status:?}"); + if status.generation == generation && status.max_frame_no < 
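    // same generation and no frames beyond the one we already hold: remote
    // history is exhausted, so record the final revision and stop pulling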
+                    tracing::debug!("wal_pull: end of history: status={:?}", status);
+                    break DatabasePullRevision::Legacy {
+                        generation: status.generation,
+                        synced_frame_no: Some(status.max_frame_no),
+                    };
+                }
+                generation += 1;
+                start_frame = 1;
+                continue;
+            }
+            WalHttpPullResult::Frames(content) => content,
+        };
+        loop {
+            while let Some(chunk) = data.poll_data()? {
+                let mut chunk = chunk.data();
+                while !chunk.is_empty() {
+                    let to_fill = (WAL_FRAME_SIZE - buffer_len).min(chunk.len());
+                    buffer.as_mut_slice()[buffer_len..buffer_len + to_fill]
+                        .copy_from_slice(&chunk[0..to_fill]);
+                    buffer_len += to_fill;
+                    chunk = &chunk[to_fill..];
+
+                    if buffer_len < WAL_FRAME_SIZE {
+                        continue;
+                    }
+                    let c = Completion::new_write(move |result| {
+                        // todo(sivukhin): we need to error out in case of partial write
+                        let Ok(size) = result else {
+                            return;
+                        };
+                        assert!(size as usize == WAL_FRAME_SIZE);
+                    });
+                    let c = frames_file.pwrite(last_offset, buffer.clone(), c)?;
+                    while !c.is_completed() {
+                        coro.yield_(ProtocolCommand::IO).await?;
+                    }
+
+                    last_offset += WAL_FRAME_SIZE;
+                    buffer_len = 0;
+                    start_frame += 1;
+
+                    let info = WalFrameInfo::from_frame_header(buffer.as_slice());
+                    if info.is_commit_frame() {
+                        committed_len = last_offset;
+                    }
+                }
+            }
+            if data.is_done()? {
+                break;
+            }
+            coro.yield_(ProtocolCommand::IO).await?;
+        }
+        if start_frame < end_frame {
+            // the server ended this chunk early - there is nothing left to pull on the server side
+            break DatabasePullRevision::Legacy {
+                generation,
+                synced_frame_no: Some(start_frame - 1),
+            };
+        }
+        if buffer_len != 0 {
+            return Err(Error::DatabaseSyncEngineError(format!(
+                "wal_pull: response has unexpected trailing data: buffer_len={}",
+                buffer_len
+            )));
+        }
+    };
-        WalHttpPullResult::Frames(content) => content,
-    };
-    loop {
-        while let Some(chunk) = data.poll_data()? {
-            let mut chunk = chunk.data();
-            while !chunk.is_empty() {
-                let to_fill = (WAL_FRAME_SIZE - buffer.len()).min(chunk.len());
-                buffer.extend_from_slice(&chunk[0..to_fill]);
-                chunk = &chunk[to_fill..];
-                assert!(
-                    buffer.capacity() == WAL_FRAME_SIZE,
-                    "buffer should not extend its capacity"
-                );
-                if buffer.len() < WAL_FRAME_SIZE {
-                    continue;
-                }
-                if !wal_session.in_txn() {
-                    wal_session.begin()?;
-                }
-                let frame_info = wal_session.insert_at(start_frame, &buffer)?;
-                if frame_info.is_commit_frame() {
-                    wal_session.end()?;
-                    // transaction boundary reached - safe to commit progress
-                    update(coro, start_frame).await?;
-                }
-                buffer.clear();
-                start_frame += 1;
-            }
-        }
-        if data.is_done()? {
-            break;
-        }
-        coro.yield_(ProtocolCommand::IO).await?;
-    }
-    if start_frame < end_frame {
-        // chunk which was sent from the server has ended early - so there is nothing left on server-side for pull
-        return Ok(WalPullResult::Done);
-    }
-    if !buffer.is_empty() {
-        return Err(Error::DatabaseSyncEngineError(format!(
-            "wal_pull: response has unexpected trailing data: buffer.len()={}",
-            buffer.len()
-        )));
-    }
-    Ok(WalPullResult::PullMore)
+    tracing::info!(
+        "wal_pull: generation={generation}, frame={start_frame}, last_offset={last_offset}, committed_len={committed_len}"
+    );
+    let c = Completion::new_trunc(move |result| {
+        let Ok(rc) = result else {
+            return;
+        };
+        assert!(rc as usize == 0);
+    });
+    let c = frames_file.truncate(committed_len, c)?;
+    while !c.is_completed() {
+        coro.yield_(ProtocolCommand::IO).await?;
+    }
+
+    let c = Completion::new_sync(move |_| {
+        // todo(sivukhin): we need to error out in case of failed sync
+    });
+    let c = frames_file.sync(c)?;
+    while !c.is_completed() {
+        coro.yield_(ProtocolCommand::IO).await?;
+    }
+
+    Ok(revision)
 }
 
 /// Push frame range [start_frame..end_frame) to the remote
@@ -207,8 +382,8 @@ pub async fn wal_pull<'a, C: ProtocolIO, U: AsyncFnMut(&'a Coro, u64) -> Result<
 /// Guarantees:
 /// 1. If there is a single client which calls wal_push, then this operation is idempotent for fixed generation
 ///    and can be called multiple times with same frame range
-pub async fn wal_push<C: ProtocolIO>(
-    coro: &Coro,
+pub async fn wal_push<C: ProtocolIO, Ctx>(
+    coro: &Coro<Ctx>,
     client: &C,
     wal_session: &mut WalSession,
     baton: Option<String>,
@@ -262,172 +437,17 @@ pub async fn wal_push(
     }
 }
 
-const TURSO_SYNC_TABLE_NAME: &str = "turso_sync_last_change_id";
-const TURSO_SYNC_CREATE_TABLE: &str =
+pub const TURSO_SYNC_TABLE_NAME: &str = "turso_sync_last_change_id";
+pub const TURSO_SYNC_CREATE_TABLE: &str =
     "CREATE TABLE IF NOT EXISTS turso_sync_last_change_id (client_id TEXT PRIMARY KEY, pull_gen INTEGER, change_id INTEGER)";
+pub const TURSO_SYNC_INSERT_LAST_CHANGE_ID: &str =
+    "INSERT INTO turso_sync_last_change_id(client_id, pull_gen, change_id) VALUES (?, ?, ?)";
+pub const TURSO_SYNC_UPSERT_LAST_CHANGE_ID: &str =
+    "INSERT INTO turso_sync_last_change_id(client_id, pull_gen, change_id) VALUES (?, ?, ?) ON CONFLICT(client_id) DO UPDATE SET pull_gen=excluded.pull_gen, change_id=excluded.change_id";
+pub const TURSO_SYNC_UPDATE_LAST_CHANGE_ID: &str =
+    "UPDATE turso_sync_last_change_id SET pull_gen = ?, change_id = ? WHERE client_id = ?";
 const TURSO_SYNC_SELECT_LAST_CHANGE_ID: &str =
     "SELECT pull_gen, change_id FROM turso_sync_last_change_id WHERE client_id = ?";
-const TURSO_SYNC_INSERT_LAST_CHANGE_ID: &str =
-    "INSERT INTO turso_sync_last_change_id(client_id, pull_gen, change_id) VALUES (?, 0, 0)";
-const TURSO_SYNC_UPDATE_LAST_CHANGE_ID: &str =
-    "UPDATE turso_sync_last_change_id SET pull_gen = ?, change_id = ? WHERE client_id = ?";
-const TURSO_SYNC_UPSERT_LAST_CHANGE_ID: &str =
-    "INSERT INTO turso_sync_last_change_id(client_id, pull_gen, change_id) VALUES (?, ?, ?) 
ON CONFLICT(client_id) DO UPDATE SET pull_gen=excluded.pull_gen, change_id=excluded.change_id"; - -/// Transfers row changes from source DB to target DB -/// In order to guarantee atomicity and avoid conflicts - method maintain last_change_id counter in the target db table turso_sync_last_change_id -pub async fn transfer_logical_changes( - coro: &Coro, - source: &DatabaseTape, - target: &DatabaseTape, - client_id: &str, - bump_pull_gen: bool, -) -> Result<()> { - tracing::info!("transfer_logical_changes: client_id={client_id}"); - let source_conn = connect_untracked(source)?; - let target_conn = connect_untracked(target)?; - - // fetch last_change_id from the target DB in order to guarantee atomic replay of changes and avoid conflicts in case of failure - let source_pull_gen = 'source_pull_gen: { - let mut select_last_change_id_stmt = - match source_conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID) { - Ok(stmt) => stmt, - Err(LimboError::ParseError(..)) => break 'source_pull_gen 0, - Err(err) => return Err(err.into()), - }; - - select_last_change_id_stmt - .bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id))); - - match run_stmt_expect_one_row(coro, &mut select_last_change_id_stmt).await? { - Some(row) => row[0].as_int().ok_or_else(|| { - Error::DatabaseSyncEngineError("unexpected source pull_gen type".to_string()) - })?, - None => { - tracing::info!("transfer_logical_changes: client_id={client_id}, turso_sync_last_change_id table is not found"); - 0 - } - } - }; - tracing::info!( - "transfer_logical_changes: client_id={client_id}, source_pull_gen={source_pull_gen}" - ); - - // fetch last_change_id from the target DB in order to guarantee atomic replay of changes and avoid conflicts in case of failure - let mut schema_stmt = target_conn.prepare(TURSO_SYNC_CREATE_TABLE)?; - exec_stmt(coro, &mut schema_stmt).await?; - - let mut select_last_change_id_stmt = target_conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID)?; - select_last_change_id_stmt.bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id))); - - let mut last_change_id = match run_stmt_expect_one_row(coro, &mut select_last_change_id_stmt) - .await? 
- { - Some(row) => { - let target_pull_gen = row[0].as_int().ok_or_else(|| { - Error::DatabaseSyncEngineError("unexpected target pull_gen type".to_string()) - })?; - let target_change_id = row[1].as_int().ok_or_else(|| { - Error::DatabaseSyncEngineError("unexpected target change_id type".to_string()) - })?; - tracing::debug!( - "transfer_logical_changes: client_id={client_id}, target_pull_gen={target_pull_gen}, target_change_id={target_change_id}" - ); - if target_pull_gen > source_pull_gen { - return Err(Error::DatabaseSyncEngineError(format!("protocol error: target_pull_gen > source_pull_gen: {target_pull_gen} > {source_pull_gen}"))); - } - if target_pull_gen == source_pull_gen { - Some(target_change_id) - } else { - Some(0) - } - } - None => { - let mut insert_last_change_id_stmt = - target_conn.prepare(TURSO_SYNC_INSERT_LAST_CHANGE_ID)?; - insert_last_change_id_stmt - .bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id))); - exec_stmt(coro, &mut insert_last_change_id_stmt).await?; - None - } - }; - - tracing::debug!( - "transfer_logical_changes: last_change_id={:?}", - last_change_id - ); - let replay_opts = DatabaseReplaySessionOpts { - use_implicit_rowid: false, - }; - - let source_schema_cookie = connect_untracked(source)?.read_schema_version()?; - - let mut session = target.start_replay_session(coro, replay_opts).await?; - - let iterate_opts = DatabaseChangesIteratorOpts { - first_change_id: last_change_id.map(|x| x + 1), - mode: DatabaseChangesIteratorMode::Apply, - ignore_schema_changes: false, - ..Default::default() - }; - let mut rows_changed = 0; - let mut changes = source.iterate_changes(iterate_opts)?; - while let Some(operation) = changes.next(coro).await? { - match &operation { - DatabaseTapeOperation::RowChange(change) => { - assert!( - last_change_id.is_none() || last_change_id.unwrap() < change.change_id, - "change id must be strictly increasing: last_change_id={:?}, change.change_id={}", - last_change_id, - change.change_id - ); - if change.table_name == TURSO_SYNC_TABLE_NAME { - continue; - } - rows_changed += 1; - // we give user full control over CDC table - so let's not emit assert here for now - if last_change_id.is_some() && last_change_id.unwrap() + 1 != change.change_id { - tracing::warn!( - "out of order change sequence: {} -> {}", - last_change_id.unwrap(), - change.change_id - ); - } - last_change_id = Some(change.change_id); - } - DatabaseTapeOperation::Commit if rows_changed > 0 || bump_pull_gen => { - tracing::info!("prepare update stmt for turso_sync_last_change_id table with client_id={} and last_change_id={:?}", client_id, last_change_id); - // update turso_sync_last_change_id table with new value before commit - let mut set_last_change_id_stmt = - session.conn().prepare(TURSO_SYNC_UPDATE_LAST_CHANGE_ID)?; - let (next_pull_gen, next_change_id) = if bump_pull_gen { - (source_pull_gen + 1, 0) - } else { - (source_pull_gen, last_change_id.unwrap_or(0)) - }; - tracing::info!("transfer_logical_changes: client_id={client_id}, set pull_gen={next_pull_gen}, change_id={next_change_id}, rows_changed={rows_changed}"); - set_last_change_id_stmt - .bind_at(1.try_into().unwrap(), Value::Integer(next_pull_gen)); - set_last_change_id_stmt - .bind_at(2.try_into().unwrap(), Value::Integer(next_change_id)); - set_last_change_id_stmt - .bind_at(3.try_into().unwrap(), Value::Text(Text::new(client_id))); - exec_stmt(coro, &mut set_last_change_id_stmt).await?; - let session_schema_cookie = session.conn().read_schema_version()?; - if session_schema_cookie <= 
source_schema_cookie { - session - .conn() - .write_schema_version(source_schema_cookie + 1)?; - } - } - _ => {} - } - session.replay(coro, operation).await?; - } - - tracing::info!("transfer_logical_changes: rows_changed={:?}", rows_changed); - Ok(()) -} fn convert_to_args(values: Vec) -> Vec { values @@ -446,16 +466,95 @@ fn convert_to_args(values: Vec) -> Vec { .collect() } -pub async fn push_logical_changes( - coro: &Coro, - client: &C, - source: &DatabaseTape, - target: &DatabaseTape, +pub async fn has_table( + coro: &Coro, + conn: &Arc, + table_name: &str, +) -> Result { + let mut stmt = + conn.prepare("SELECT COUNT(*) FROM sqlite_schema WHERE type = 'table' AND name = ?")?; + stmt.bind_at(1.try_into().unwrap(), Value::Text(Text::new(table_name))); + + let count = match run_stmt_expect_one_row(coro, &mut stmt).await? { + Some(row) => row[0] + .as_int() + .ok_or_else(|| Error::DatabaseSyncEngineError("unexpected column type".to_string()))?, + _ => panic!("expected single row"), + }; + Ok(count > 0) +} + +pub async fn count_local_changes( + coro: &Coro, + conn: &Arc, + change_id: i64, +) -> Result { + let mut stmt = conn.prepare("SELECT COUNT(*) FROM turso_cdc WHERE change_id > ?")?; + stmt.bind_at(1.try_into().unwrap(), Value::Integer(change_id)); + + let count = match run_stmt_expect_one_row(coro, &mut stmt).await? { + Some(row) => row[0] + .as_int() + .ok_or_else(|| Error::DatabaseSyncEngineError("unexpected column type".to_string()))?, + _ => panic!("expected single row"), + }; + Ok(count) +} + +pub async fn update_last_change_id( + coro: &Coro, + conn: &Arc, client_id: &str, + pull_gen: i64, + change_id: i64, ) -> Result<()> { - tracing::info!("push_logical_changes: client_id={client_id}"); - let source_conn = connect_untracked(source)?; - let target_conn = connect_untracked(target)?; + tracing::info!( + "update_last_change_id(client_id={client_id}): pull_gen={pull_gen}, change_id={change_id}" + ); + conn.execute(TURSO_SYNC_CREATE_TABLE)?; + tracing::info!("update_last_change_id(client_id={client_id}): initialized table"); + let mut select_stmt = conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID)?; + select_stmt.bind_at( + 1.try_into().unwrap(), + turso_core::Value::Text(turso_core::types::Text::new(client_id)), + ); + let row = run_stmt_expect_one_row(coro, &mut select_stmt).await?; + tracing::info!("update_last_change_id(client_id={client_id}): selected client row if any"); + + if let Some(_) = row { + let mut update_stmt = conn.prepare(TURSO_SYNC_UPDATE_LAST_CHANGE_ID)?; + update_stmt.bind_at(1.try_into().unwrap(), turso_core::Value::Integer(pull_gen)); + update_stmt.bind_at(2.try_into().unwrap(), turso_core::Value::Integer(change_id)); + update_stmt.bind_at( + 3.try_into().unwrap(), + turso_core::Value::Text(turso_core::types::Text::new(client_id)), + ); + run_stmt_ignore_rows(coro, &mut update_stmt).await?; + tracing::info!("update_last_change_id(client_id={client_id}): updated row for the client"); + } else { + let mut update_stmt = conn.prepare(TURSO_SYNC_INSERT_LAST_CHANGE_ID)?; + update_stmt.bind_at( + 1.try_into().unwrap(), + turso_core::Value::Text(turso_core::types::Text::new(client_id)), + ); + update_stmt.bind_at(2.try_into().unwrap(), turso_core::Value::Integer(pull_gen)); + update_stmt.bind_at(3.try_into().unwrap(), turso_core::Value::Integer(change_id)); + run_stmt_ignore_rows(coro, &mut update_stmt).await?; + tracing::info!( + "update_last_change_id(client_id={client_id}): inserted new row for the client" + ); + } + + Ok(()) +} + +pub async fn fetch_last_change_id( + 
coro: &Coro, + client: &C, + source_conn: &Arc, + client_id: &str, +) -> Result<(i64, Option)> { + tracing::info!("fetch_last_change_id: client_id={client_id}"); // fetch last_change_id from the target DB in order to guarantee atomic replay of changes and avoid conflicts in case of failure let source_pull_gen = 'source_pull_gen: { @@ -474,61 +573,101 @@ pub async fn push_logical_changes( Error::DatabaseSyncEngineError("unexpected source pull_gen type".to_string()) })?, None => { - tracing::info!("push_logical_changes: client_id={client_id}, turso_sync_last_change_id table is not found"); + tracing::info!("fetch_last_change_id: client_id={client_id}, turso_sync_last_change_id table is not found"); 0 } } }; tracing::info!( - "push_logical_changes: client_id={client_id}, source_pull_gen={source_pull_gen}" + "fetch_last_change_id: client_id={client_id}, source_pull_gen={source_pull_gen}" ); // fetch last_change_id from the target DB in order to guarantee atomic replay of changes and avoid conflicts in case of failure - let mut schema_stmt = target_conn.prepare(TURSO_SYNC_CREATE_TABLE)?; - exec_stmt(coro, &mut schema_stmt).await?; - - let mut select_last_change_id_stmt = target_conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID)?; - select_last_change_id_stmt.bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id))); - - let mut last_change_id = match run_stmt_expect_one_row(coro, &mut select_last_change_id_stmt) - .await? - { - Some(row) => { - let target_pull_gen = row[0].as_int().ok_or_else(|| { - Error::DatabaseSyncEngineError("unexpected target pull_gen type".to_string()) - })?; - let target_change_id = row[1].as_int().ok_or_else(|| { - Error::DatabaseSyncEngineError("unexpected target change_id type".to_string()) - })?; - tracing::debug!( - "push_logical_changes: client_id={client_id}, target_pull_gen={target_pull_gen}, target_change_id={target_change_id}" - ); - if target_pull_gen > source_pull_gen { - return Err(Error::DatabaseSyncEngineError(format!("protocol error: target_pull_gen > source_pull_gen: {target_pull_gen} > {source_pull_gen}"))); - } - if target_pull_gen == source_pull_gen { - Some(target_change_id) - } else { - Some(0) - } - } - None => { - let mut insert_last_change_id_stmt = - target_conn.prepare(TURSO_SYNC_INSERT_LAST_CHANGE_ID)?; - insert_last_change_id_stmt - .bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id))); - exec_stmt(coro, &mut insert_last_change_id_stmt).await?; - None - } + let init_hrana_request = server_proto::PipelineReqBody { + baton: None, + requests: vec![ + // read pull_gen, change_id values for current client if they were set before + StreamRequest::Execute(ExecuteStreamReq { + stmt: Stmt { + sql: Some(TURSO_SYNC_SELECT_LAST_CHANGE_ID.to_string()), + sql_id: None, + args: vec![server_proto::Value::Text { + value: client_id.to_string(), + }], + named_args: Vec::new(), + want_rows: Some(true), + replication_index: None, + }, + }), + ] + .into(), }; + let response = match sql_execute_http(coro, client, init_hrana_request).await { + Ok(response) => response, + Err(Error::DatabaseSyncEngineError(err)) if err.contains("no such table") => { + return Ok((source_pull_gen, None)); + } + Err(err) => return Err(err), + }; + assert!(response.len() == 1); + let last_change_id_response = &response[0]; + tracing::debug!("fetch_last_change_id: response={:?}", response); + assert!(last_change_id_response.rows.len() <= 1); + if last_change_id_response.rows.is_empty() { + return Ok((source_pull_gen, None)); + } + let row = 
&last_change_id_response.rows[0].values; + let server_proto::Value::Integer { + value: target_pull_gen, + } = row[0] + else { + return Err(Error::DatabaseSyncEngineError( + "unexpected target pull_gen type".to_string(), + )); + }; + let server_proto::Value::Integer { + value: target_change_id, + } = row[1] + else { + return Err(Error::DatabaseSyncEngineError( + "unexpected target change_id type".to_string(), + )); + }; + tracing::debug!( + "fetch_last_change_id: client_id={client_id}, target_pull_gen={target_pull_gen}, target_change_id={target_change_id}" + ); + if target_pull_gen > source_pull_gen { + return Err(Error::DatabaseSyncEngineError(format!("protocol error: target_pull_gen > source_pull_gen: {target_pull_gen} > {source_pull_gen}"))); + } + let last_change_id = if target_pull_gen == source_pull_gen { + Some(target_change_id) + } else { + Some(0) + }; + Ok((source_pull_gen, last_change_id)) +} + +pub async fn push_logical_changes( + coro: &Coro, + client: &C, + source: &DatabaseTape, + client_id: &str, + opts: &DatabaseSyncEngineOpts, +) -> Result<(i64, i64)> { + tracing::info!("push_logical_changes: client_id={client_id}"); + let source_conn = connect_untracked(source)?; + + let (source_pull_gen, mut last_change_id) = + fetch_last_change_id(coro, client, &source_conn, client_id).await?; + tracing::debug!("push_logical_changes: last_change_id={:?}", last_change_id); let replay_opts = DatabaseReplaySessionOpts { use_implicit_rowid: false, + transform: None, }; - let conn = connect_untracked(target)?; - let generator = DatabaseReplayGenerator::new(conn, replay_opts); + let generator = DatabaseReplayGenerator::new(source_conn, replay_opts); let iterate_opts = DatabaseChangesIteratorOpts { first_change_id: last_change_id.map(|x| x + 1), @@ -568,6 +707,10 @@ pub async fn push_logical_changes( if change.table_name == TURSO_SYNC_TABLE_NAME { continue; } + let ignore = &opts.tables_ignore; + if ignore.iter().any(|x| &change.table_name == x) { + continue; + } rows_changed += 1; // we give user full control over CDC table - so let's not emit assert here for now if last_change_id.is_some() && last_change_id.unwrap() + 1 != change.change_id { @@ -579,19 +722,39 @@ pub async fn push_logical_changes( } last_change_id = Some(change.change_id); let replay_info = generator.replay_info(coro, &change).await?; + if !replay_info.is_ddl_replay { + if let Some(transform) = &opts.transform { + let mutation = generator.create_mutation(&replay_info, &change)?; + if let Some(statement) = transform(&coro.ctx.borrow(), mutation)? 
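+                    // The transform hook lets the embedder rewrite a row mutation into
+                    // a custom statement before it is pushed; returning None falls
+                    // through to the generated replay statements below.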
{ + tracing::info!( + "push_logical_changes: use mutation from custom transformer: sql={}, values={:?}", + statement.sql, + statement.values + ); + sql_over_http_requests.push(Stmt { + sql: Some(statement.sql), + sql_id: None, + args: convert_to_args(statement.values), + named_args: Vec::new(), + want_rows: Some(false), + replication_index: None, + }); + continue; + } + } + } let change_type = (&change.change).into(); match change.change { DatabaseTapeRowChangeType::Delete { before } => { - assert!(replay_info.len() == 1); let values = generator.replay_values( - &replay_info[0], + &replay_info, change_type, change.id, before, None, ); sql_over_http_requests.push(Stmt { - sql: Some(replay_info[0].query.clone()), + sql: Some(replay_info.query.clone()), sql_id: None, args: convert_to_args(values), named_args: Vec::new(), @@ -600,16 +763,15 @@ pub async fn push_logical_changes( }) } DatabaseTapeRowChangeType::Insert { after } => { - assert!(replay_info.len() == 1); let values = generator.replay_values( - &replay_info[0], + &replay_info, change_type, change.id, after, None, ); sql_over_http_requests.push(Stmt { - sql: Some(replay_info[0].query.clone()), + sql: Some(replay_info.query.clone()), sql_id: None, args: convert_to_args(values), named_args: Vec::new(), @@ -622,16 +784,15 @@ pub async fn push_logical_changes( updates: Some(updates), .. } => { - assert!(replay_info.len() == 1); let values = generator.replay_values( - &replay_info[0], + &replay_info, change_type, change.id, after, Some(updates), ); sql_over_http_requests.push(Stmt { - sql: Some(replay_info[0].query.clone()), + sql: Some(replay_info.query.clone()), sql_id: None, args: convert_to_args(values), named_args: Vec::new(), @@ -640,35 +801,19 @@ pub async fn push_logical_changes( }) } DatabaseTapeRowChangeType::Update { - before, after, updates: None, + .. 
} => { - assert!(replay_info.len() == 2); let values = generator.replay_values( - &replay_info[0], - change_type, - change.id, - before, - None, - ); - sql_over_http_requests.push(Stmt { - sql: Some(replay_info[0].query.clone()), - sql_id: None, - args: convert_to_args(values), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }); - let values = generator.replay_values( - &replay_info[1], + &replay_info, change_type, change.id, after, None, ); sql_over_http_requests.push(Stmt { - sql: Some(replay_info[1].query.clone()), + sql: Some(replay_info.query.clone()), sql_id: None, args: convert_to_args(values), named_args: Vec::new(), @@ -682,9 +827,8 @@ pub async fn push_logical_changes( if rows_changed > 0 { tracing::info!("prepare update stmt for turso_sync_last_change_id table with client_id={} and last_change_id={:?}", client_id, last_change_id); // update turso_sync_last_change_id table with new value before commit - let (next_pull_gen, next_change_id) = - (source_pull_gen, last_change_id.unwrap_or(0)); - tracing::info!("transfer_logical_changes: client_id={client_id}, set pull_gen={next_pull_gen}, change_id={next_change_id}, rows_changed={rows_changed}"); + let next_change_id = last_change_id.unwrap_or(0); + tracing::info!("push_logical_changes: client_id={client_id}, set pull_gen={source_pull_gen}, change_id={next_change_id}, rows_changed={rows_changed}"); sql_over_http_requests.push(Stmt { sql: Some(TURSO_SYNC_UPSERT_LAST_CHANGE_ID.to_string()), sql_id: None, @@ -693,7 +837,7 @@ pub async fn push_logical_changes( value: client_id.to_string(), }, server_proto::Value::Integer { - value: next_pull_gen, + value: source_pull_gen, }, server_proto::Value::Integer { value: next_change_id, @@ -716,8 +860,8 @@ pub async fn push_logical_changes( } } - tracing::debug!("hrana request: {:?}", sql_over_http_requests); - let request = server_proto::PipelineReqBody { + tracing::trace!("hrana request: {:?}", sql_over_http_requests); + let replay_hrana_request = server_proto::PipelineReqBody { baton: None, requests: sql_over_http_requests .into_iter() @@ -725,74 +869,40 @@ pub async fn push_logical_changes( .collect(), }; - sql_execute_http(coro, client, request).await?; + let _ = sql_execute_http(coro, client, replay_hrana_request).await?; tracing::info!("push_logical_changes: rows_changed={:?}", rows_changed); - Ok(()) + Ok((source_pull_gen, last_change_id.unwrap_or(0))) } -/// Replace WAL frames [target_wal_match_watermark..) in the target DB with frames [source_wal_match_watermark..) 
from source DB -/// Return the position in target DB wal which logically equivalent to the source_sync_watermark in the source DB WAL -pub async fn transfer_physical_changes( - coro: &Coro, - source: &DatabaseTape, - target_session: WalSession, - source_wal_match_watermark: u64, - source_sync_watermark: u64, - target_wal_match_watermark: u64, -) -> Result { - tracing::info!("transfer_physical_changes: source_wal_match_watermark={source_wal_match_watermark}, source_sync_watermark={source_sync_watermark}, target_wal_match_watermark={target_wal_match_watermark}"); - - let source_conn = connect(coro, source).await?; - let mut source_session = WalSession::new(source_conn.clone()); - source_session.begin()?; - - let source_frames_count = source_conn.wal_state()?.max_frame; - assert!( - source_frames_count >= source_wal_match_watermark, - "watermark can't be greater than current frames count: {source_frames_count} vs {source_wal_match_watermark}", - ); - if source_frames_count == source_wal_match_watermark { - assert!(source_sync_watermark == source_wal_match_watermark); - return Ok(target_wal_match_watermark); - } - assert!( - (source_wal_match_watermark..=source_frames_count).contains(&source_sync_watermark), - "source_sync_watermark={source_sync_watermark} must be in range: {source_wal_match_watermark}..={source_frames_count}", - ); - - let target_sync_watermark = { - let mut target_session = DatabaseWalSession::new(coro, target_session).await?; - tracing::info!("rollback_changes_after: {target_wal_match_watermark}"); - - target_session.rollback_changes_after(target_wal_match_watermark)?; - let mut last_frame_info = None; - let mut frame = vec![0u8; WAL_FRAME_SIZE]; - let mut target_sync_watermark = target_session.frames_count()?; - tracing::info!( - "transfer_physical_changes: start={}, end={}", - source_wal_match_watermark + 1, - source_frames_count - ); - for source_frame_no in source_wal_match_watermark + 1..=source_frames_count { - let frame_info = source_conn.wal_get_frame(source_frame_no, &mut frame)?; - tracing::debug!("append page {} to target DB", frame_info.page_no); - target_session.append_page(frame_info.page_no, &frame[WAL_FRAME_HEADER..])?; - if source_frame_no == source_sync_watermark { - target_sync_watermark = target_session.frames_count()? 
+ 1; // +1 because page will be actually commited on next iteration - tracing::info!("set target_sync_watermark to {}", target_sync_watermark); - } - last_frame_info = Some(frame_info); +pub async fn read_wal_salt( + coro: &Coro, + wal: &Arc, +) -> Result>> { + let buffer = Arc::new(Buffer::new_temporary(WAL_HEADER)); + let c = Completion::new_read(buffer.clone(), |result| { + let Ok((buffer, len)) = result else { + return; + }; + if (len as usize) < WAL_HEADER { + buffer.as_mut_slice().fill(0); } - let db_size = last_frame_info.unwrap().db_size; - tracing::debug!("commit WAL session to target with db_size={db_size}"); - target_session.commit(db_size)?; - assert!(target_sync_watermark != 0); - target_sync_watermark - }; - Ok(target_sync_watermark) + }); + let c = wal.pread(0, c)?; + while !c.is_completed() { + coro.yield_(ProtocolCommand::IO).await?; + } + if buffer.as_mut_slice() == &[0u8; WAL_HEADER] { + return Ok(None); + } + let salt1 = u32::from_be_bytes(buffer.as_slice()[16..20].try_into().unwrap()); + let salt2 = u32::from_be_bytes(buffer.as_slice()[20..24].try_into().unwrap()); + Ok(Some(vec![salt1, salt2])) } -pub async fn checkpoint_wal_file(coro: &Coro, conn: &Arc) -> Result<()> { +pub async fn checkpoint_wal_file( + coro: &Coro, + conn: &Arc, +) -> Result<()> { let mut checkpoint_stmt = conn.prepare("PRAGMA wal_checkpoint(TRUNCATE)")?; loop { match checkpoint_stmt.step()? { @@ -809,8 +919,150 @@ pub async fn checkpoint_wal_file(coro: &Coro, conn: &Arc Ok(()) } -pub async fn reset_wal_file( - coro: &Coro, +pub async fn bootstrap_db_file( + coro: &Coro, + client: &C, + io: &Arc, + main_db_path: &str, + protocol: DatabaseSyncEngineProtocolVersion, +) -> Result { + match protocol { + DatabaseSyncEngineProtocolVersion::Legacy => { + bootstrap_db_file_legacy(coro, client, io, main_db_path).await + } + DatabaseSyncEngineProtocolVersion::V1 => { + bootstrap_db_file_v1(coro, client, io, main_db_path).await + } + } +} + +pub async fn bootstrap_db_file_v1( + coro: &Coro, + client: &C, + io: &Arc, + main_db_path: &str, +) -> Result { + let mut bytes = BytesMut::new(); + let completion = client.http( + "GET", + "/pull-updates", + None, + &[ + ("content-type", "application/protobuf"), + ("accept-encoding", "application/protobuf"), + ], + )?; + let Some(header) = + wait_proto_message::(coro, &completion, &mut bytes).await? + else { + return Err(Error::DatabaseSyncEngineError(format!( + "no header returned in the pull-updates protobuf call" + ))); + }; + tracing::info!( + "bootstrap_db_file(path={}): got header={:?}", + main_db_path, + header + ); + let file = io.open_file(main_db_path, OpenFlags::Create, false)?; + let c = Completion::new_trunc(move |result| { + let Ok(rc) = result else { + return; + }; + assert!(rc as usize == 0); + }); + let c = file.truncate(header.db_size as usize * PAGE_SIZE, c)?; + while !c.is_completed() { + coro.yield_(ProtocolCommand::IO).await?; + } + + let buffer = Arc::new(Buffer::new_temporary(PAGE_SIZE)); + while let Some(page_data) = + wait_proto_message::(coro, &completion, &mut bytes).await? 
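+    // Each PageData message carries exactly one database page keyed by page_id, so
+    // the loop below can write it at byte offset page_id * PAGE_SIZE directly.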
+    {
+        let offset = page_data.page_id as usize * PAGE_SIZE;
+        let page = decode_page(&header, page_data)?;
+        if page.len() != PAGE_SIZE {
+            return Err(Error::DatabaseSyncEngineError(format!(
+                "page has unexpected size: {} != {}",
+                page.len(),
+                PAGE_SIZE
+            )));
+        }
+        buffer.as_mut_slice().copy_from_slice(&page);
+        let c = Completion::new_write(move |result| {
+            // todo(sivukhin): we need to error out in case of partial write
+            let Ok(size) = result else {
+                return;
+            };
+            assert!(size as usize == PAGE_SIZE);
+        });
+        let c = file.pwrite(offset, buffer.clone(), c)?;
+        while !c.is_completed() {
+            coro.yield_(ProtocolCommand::IO).await?;
+        }
+    }
+    Ok(DatabasePullRevision::V1 {
+        revision: header.server_revision,
+    })
+}
+
+fn decode_page(header: &PullUpdatesRespProtoBody, page_data: PageData) -> Result<Vec<u8>> {
+    if header.raw_encoding.is_some() && header.zstd_encoding.is_some() {
+        return Err(Error::DatabaseSyncEngineError(
+            "both raw_encoding and zstd_encoding are set".to_string(),
+        ));
+    }
+    if header.raw_encoding.is_none() && header.zstd_encoding.is_none() {
+        return Err(Error::DatabaseSyncEngineError(
+            "neither raw_encoding nor zstd_encoding is set".to_string(),
+        ));
+    }
+
+    if header.raw_encoding.is_some() {
+        return Ok(page_data.encoded_page.to_vec());
+    }
+    Err(Error::DatabaseSyncEngineError(
+        "zstd encoding is not supported".to_string(),
+    ))
+}
+
+pub async fn bootstrap_db_file_legacy<C: ProtocolIO, Ctx>(
+    coro: &Coro<Ctx>,
+    client: &C,
+    io: &Arc<dyn turso_core::IO>,
+    main_db_path: &str,
+) -> Result<DatabasePullRevision> {
+    tracing::info!("bootstrap_db_file(path={})", main_db_path);
+
+    let start_time = std::time::Instant::now();
+    // clean up all files left from a previous bootstrap attempt;
+    // we shouldn't have written any WAL file - but truncate it too for safety
+    if let Some(file) = io.try_open(main_db_path)? {
+        io.truncate(coro, file, 0).await?;
+    }
+    if let Some(file) = io.try_open(&format!("{}-wal", main_db_path))? {
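+        // a stale WAL left on disk would otherwise be replayed on top of the
+        // freshly bootstrapped db image, so it is truncated as well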
+        io.truncate(coro, file, 0).await?;
+    }
+
+    let file = io.create(main_db_path)?;
+    let db_info = db_bootstrap(coro, client, file).await?;
+
+    let elapsed = std::time::Instant::now().duration_since(start_time);
+    tracing::info!(
+        "bootstrap_db_file(path={}): finished: elapsed={:?}",
+        main_db_path,
+        elapsed
+    );
+
+    Ok(DatabasePullRevision::Legacy {
+        generation: db_info.current_generation,
+        synced_frame_no: None,
+    })
+}
+
-pub async fn reset_wal_file(
-    coro: &Coro,
+pub async fn reset_wal_file<Ctx>(
+    coro: &Coro<Ctx>,
     wal: Arc<dyn turso_core::File>,
     frames_count: u64,
 ) -> Result<()> {
@@ -821,8 +1073,8 @@ pub async fn reset_wal_file(
         WAL_HEADER + WAL_FRAME_SIZE * (frames_count as usize)
     };
     tracing::debug!("reset db wal to the size of {} frames", frames_count);
-    let c = Completion::new_trunc(move |res| {
-        let Ok(rc) = res else {
+    let c = Completion::new_trunc(move |result| {
+        let Ok(rc) = result else {
            return;
        };
        assert!(rc as usize == 0);
@@ -834,13 +1086,13 @@ pub async fn reset_wal_file(
     Ok(())
 }
 
-async fn sql_execute_http(
-    coro: &Coro,
+async fn sql_execute_http<C: ProtocolIO, Ctx>(
+    coro: &Coro<Ctx>,
     client: &C,
     request: server_proto::PipelineReqBody,
-) -> Result<()> {
+) -> Result<Vec<server_proto::StmtResult>> {
     let body = serde_json::to_vec(&request)?;
-    let completion = client.http("POST", "/v2/pipeline", Some(body))?;
+    let completion = client.http("POST", "/v2/pipeline", Some(body), &[])?;
     let status = wait_status(coro, &completion).await?;
     if status != http::StatusCode::OK {
         let error = format!("sql_execute_http: unexpected status code: {status}");
@@ -848,18 +1100,32 @@ async fn sql_execute_http(
     }
     let response = wait_full_body(coro, &completion).await?;
     let response: server_proto::PipelineRespBody = serde_json::from_slice(&response)?;
+    tracing::debug!("hrana response: {:?}", response);
+    let mut results = Vec::new();
     for result in response.results {
-        if let server_proto::StreamResult::Error { error } = result {
-            return Err(Error::DatabaseSyncEngineError(format!(
-                "failed to execute sql: {error:?}"
-            )));
+        match result {
+            server_proto::StreamResult::Error { error } => {
+                return Err(Error::DatabaseSyncEngineError(format!(
+                    "failed to execute sql: {error:?}"
+                )))
+            }
+            server_proto::StreamResult::None => {
+                return Err(Error::DatabaseSyncEngineError(
+                    "unexpected None result".to_string(),
+                ));
+            }
+            server_proto::StreamResult::Ok { response } => match response {
+                server_proto::StreamResponse::Execute(execute) => {
+                    results.push(execute.result);
+                }
+            },
         }
     }
-    Ok(())
+    Ok(results)
 }
 
-async fn wal_pull_http(
-    coro: &Coro,
+async fn wal_pull_http<C: ProtocolIO, Ctx>(
+    coro: &Coro<Ctx>,
     client: &C,
     generation: u64,
     start_frame: u64,
@@ -869,6 +1135,7 @@ async fn wal_pull_http(
         "GET",
         &format!("/sync/{generation}/{start_frame}/{end_frame}"),
         None,
+        &[],
     )?;
     let status = wait_status(coro, &completion).await?;
     if status == http::StatusCode::BAD_REQUEST {
@@ -888,8 +1155,8 @@ async fn wal_pull_http(
     Ok(WalHttpPullResult::Frames(completion))
 }
 
-async fn wal_push_http(
-    coro: &Coro,
+async fn wal_push_http<C: ProtocolIO, Ctx>(
+    coro: &Coro<Ctx>,
     client: &C,
     baton: Option<String>,
     generation: u64,
@@ -904,6 +1171,7 @@ async fn wal_push_http(
         "POST",
         &format!("/sync/{generation}/{start_frame}/{end_frame}{baton}"),
         Some(frames),
+        &[],
     )?;
     let status = wait_status(coro, &completion).await?;
     let status_body = wait_full_body(coro, &completion).await?;
     Ok(serde_json::from_slice(&status_body)?)
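+    // note: the decoded JSON status presumably carries the server baton which the
+    // caller threads into the next wal_push call to continue the same session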
 }
 
-async fn db_info_http<C: ProtocolIO>(coro: &Coro, client: &C) -> Result<DbSyncInfo> {
-    let completion = client.http("GET", "/info", None)?;
+async fn db_info_http<C: ProtocolIO, Ctx>(coro: &Coro<Ctx>, client: &C) -> Result<DbSyncInfo> {
+    let completion = client.http("GET", "/info", None, &[])?;
     let status = wait_status(coro, &completion).await?;
     let status_body = wait_full_body(coro, &completion).await?;
     if status != http::StatusCode::OK {
@@ -928,12 +1196,12 @@ async fn db_info_http<C: ProtocolIO>(coro: &Coro, client: &C) -> Result<DbSyncInfo> {
-async fn db_bootstrap_http<C: ProtocolIO>(
-    coro: &Coro,
+async fn db_bootstrap_http<C: ProtocolIO, Ctx>(
+    coro: &Coro<Ctx>,
     client: &C,
     generation: u64,
 ) -> Result<C::DataCompletion> {
-    let completion = client.http("GET", &format!("/export/{generation}"), None)?;
+    let completion = client.http("GET", &format!("/export/{generation}"), None, &[])?;
     let status = wait_status(coro, &completion).await?;
     if status != http::StatusCode::OK.as_u16() {
         return Err(Error::DatabaseSyncEngineError(format!(
@@ -943,14 +1211,76 @@ async fn db_bootstrap_http<C: ProtocolIO>(
     Ok(completion)
 }
 
-pub async fn wait_status(coro: &Coro, completion: &impl DataCompletion) -> Result<u16> {
+pub async fn wait_status<Ctx>(coro: &Coro<Ctx>, completion: &impl DataCompletion) -> Result<u16> {
     while completion.status()?.is_none() {
         coro.yield_(ProtocolCommand::IO).await?;
     }
     Ok(completion.status()?.unwrap())
 }
 
-pub async fn wait_full_body(coro: &Coro, completion: &impl DataCompletion) -> Result<Vec<u8>> {
+#[inline(always)]
+pub fn read_varint(buf: &[u8]) -> Result<Option<(usize, usize)>> {
+    let mut v: u64 = 0;
+    for i in 0..9 {
+        match buf.get(i) {
+            Some(c) => {
+                v = (((c & 0x7f) as u64) << (i * 7)) | v;
+                if (c & 0x80) == 0 {
+                    return Ok(Some((v as usize, i + 1)));
+                }
+            }
+            None => return Ok(None),
+        }
+    }
+    Err(Error::DatabaseSyncEngineError(format!(
+        "invalid varint bytes: {:?}",
+        &buf[0..=8]
+    )))
+}
+
+pub async fn wait_proto_message<Ctx, T: prost::Message + Default>(
+    coro: &Coro<Ctx>,
+    completion: &impl DataCompletion,
+    bytes: &mut BytesMut,
+) -> Result<Option<T>> {
+    let start_time = std::time::Instant::now();
+    loop {
+        let length = read_varint(&bytes)?;
+        let not_enough_bytes = match length {
+            None => true,
+            Some((message_length, prefix_length)) => message_length + prefix_length > bytes.len(),
+        };
+        if not_enough_bytes {
+            if let Some(poll) = completion.poll_data()? {
+                bytes.extend_from_slice(poll.data());
+            } else if !completion.is_done()? {
+                coro.yield_(ProtocolCommand::IO).await?;
+            } else if bytes.is_empty() {
+                return Ok(None);
+            } else {
+                return Err(Error::DatabaseSyncEngineError(
+                    "unexpected end of protobuf message".to_string(),
+                ));
+            }
+            continue;
+        }
+        let (message_length, prefix_length) = length.unwrap();
+        let message = T::decode_length_delimited(&**bytes).map_err(|e| {
+            Error::DatabaseSyncEngineError(format!("unable to deserialize protobuf message: {e}"))
+        })?;
+        let _ = bytes.split_to(message_length + prefix_length);
+        tracing::debug!(
+            "wait_proto_message: elapsed={:?}",
+            std::time::Instant::now().duration_since(start_time)
+        );
+        return Ok(Some(message));
+    }
+}
+
+pub async fn wait_full_body<Ctx>(
+    coro: &Coro<Ctx>,
+    completion: &impl DataCompletion,
+) -> Result<Vec<u8>> {
     let mut bytes = Vec::new();
     loop {
         while let Some(poll) = completion.poll_data()? 
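+        // drain every chunk that is already available before yielding back to the
+        // IO driver; the body is complete only once is_done() reports true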
{ @@ -965,183 +1295,89 @@ pub async fn wait_full_body(coro: &Coro, completion: &impl DataCompletion) -> Re } #[cfg(test)] -pub mod tests { - use std::sync::Arc; +mod tests { + use std::cell::RefCell; - use tempfile::NamedTempFile; - use turso_core::Value; + use bytes::{Bytes, BytesMut}; + use prost::Message; use crate::{ - database_sync_operations::{transfer_logical_changes, transfer_physical_changes}, - database_tape::{run_stmt_once, DatabaseTape, DatabaseTapeOpts}, - wal_session::WalSession, + database_sync_operations::wait_proto_message, + protocol_io::{DataCompletion, DataPollResult}, + server_proto::PageData, + types::Coro, Result, }; - #[test] - pub fn test_transfer_logical_changes() { - let temp_file1 = NamedTempFile::new().unwrap(); - let db_path1 = temp_file1.path().to_str().unwrap(); - let temp_file2 = NamedTempFile::new().unwrap(); - let db_path2 = temp_file2.path().to_str().unwrap(); + struct TestPollResult(Vec); - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, true).unwrap(); - let db1 = Arc::new(DatabaseTape::new(db1)); + impl DataPollResult for TestPollResult { + fn data(&self) -> &[u8] { + &self.0 + } + } - let db2 = turso_core::Database::open_file(io.clone(), db_path2, false, true).unwrap(); - let db2 = Arc::new(DatabaseTape::new(db2)); + struct TestCompletion { + data: RefCell, + chunk: usize, + } - let mut gen = genawaiter::sync::Gen::new(|coro| async move { - let conn1 = db1.connect(&coro).await?; - conn1.execute("CREATE TABLE t(x, y)").unwrap(); - conn1 - .execute("INSERT INTO t VALUES (1, 2), (3, 4), (5, 6)") - .unwrap(); + impl DataCompletion for TestCompletion { + type DataPollResult = TestPollResult; - let conn2 = db2.connect(&coro).await.unwrap(); + fn status(&self) -> crate::Result> { + Ok(Some(200)) + } - transfer_logical_changes(&coro, &db1, &db2, "id-1", false) - .await - .unwrap(); - - let mut rows = Vec::new(); - let mut stmt = conn2.prepare("SELECT x, y FROM t").unwrap(); - while let Some(row) = run_stmt_once(&coro, &mut stmt).await.unwrap() { - rows.push(row.get_values().cloned().collect::>()); + fn poll_data(&self) -> crate::Result> { + let mut data = self.data.borrow_mut(); + let len = data.len(); + let chunk = data.split_to(len.min(self.chunk)); + if chunk.is_empty() { + Ok(None) + } else { + Ok(Some(TestPollResult(chunk.to_vec()))) } - assert_eq!( - rows, - vec![ - vec![Value::Integer(1), Value::Integer(2)], - vec![Value::Integer(3), Value::Integer(4)], - vec![Value::Integer(5), Value::Integer(6)], - ] - ); + } - conn1.execute("INSERT INTO t VALUES (7, 8)").unwrap(); - transfer_logical_changes(&coro, &db1, &db2, "id-1", false) - .await - .unwrap(); - - let mut rows = Vec::new(); - let mut stmt = conn2.prepare("SELECT x, y FROM t").unwrap(); - while let Some(row) = run_stmt_once(&coro, &mut stmt).await.unwrap() { - rows.push(row.get_values().cloned().collect::>()); - } - assert_eq!( - rows, - vec![ - vec![Value::Integer(1), Value::Integer(2)], - vec![Value::Integer(3), Value::Integer(4)], - vec![Value::Integer(5), Value::Integer(6)], - vec![Value::Integer(7), Value::Integer(8)], - ] - ); - - Result::Ok(()) - }); - loop { - match gen.resume_with(Ok(())) { - genawaiter::GeneratorState::Yielded(..) 
=> io.run_once().unwrap(), - genawaiter::GeneratorState::Complete(result) => { - result.unwrap(); - break; - } - } + fn is_done(&self) -> crate::Result { + Ok(self.data.borrow().len() == 0) } } #[test] - pub fn test_transfer_physical_changes() { - let temp_file1 = NamedTempFile::new().unwrap(); - let db_path1 = temp_file1.path().to_str().unwrap(); - let temp_file2 = NamedTempFile::new().unwrap(); - let db_path2 = temp_file2.path().to_str().unwrap(); - - let opts = DatabaseTapeOpts { - cdc_mode: Some("off".to_string()), - cdc_table: None, + pub fn wait_proto_message_test() { + let mut data = Vec::new(); + for i in 0..1024 { + let page = PageData { + page_id: i as u64, + encoded_page: vec![0u8; 16 * 1024].into(), + }; + data.extend_from_slice(&page.encode_length_delimited_to_vec()); + } + let completion = TestCompletion { + data: RefCell::new(data.into()), + chunk: 128, }; - let io: Arc = Arc::new(turso_core::PlatformIO::new().unwrap()); - let db1 = turso_core::Database::open_file(io.clone(), db_path1, false, true).unwrap(); - let db1 = Arc::new(DatabaseTape::new_with_opts(db1, opts.clone())); - - let db2 = turso_core::Database::open_file(io.clone(), db_path2, false, true).unwrap(); - let db2 = Arc::new(DatabaseTape::new_with_opts(db2, opts.clone())); - - let mut gen = genawaiter::sync::Gen::new(|coro| async move { - let conn1 = db1.connect(&coro).await?; - conn1.execute("CREATE TABLE t(x, y)")?; - conn1.execute("INSERT INTO t VALUES (1, 2)")?; - let conn1_match_watermark = conn1.wal_state().unwrap().max_frame; - conn1.execute("INSERT INTO t VALUES (3, 4)")?; - let conn1_sync_watermark = conn1.wal_state().unwrap().max_frame; - conn1.execute("INSERT INTO t VALUES (5, 6)")?; - - let conn2 = db2.connect(&coro).await?; - conn2.execute("CREATE TABLE t(x, y)")?; - conn2.execute("INSERT INTO t VALUES (1, 2)")?; - let conn2_match_watermark = conn2.wal_state().unwrap().max_frame; - conn2.execute("INSERT INTO t VALUES (5, 6)")?; - - // db1 WAL frames: [A1 A2] [A3] [A4] (sync_watermark) [A5] - // db2 WAL frames: [B1 B2] [B3] [B4] - - let session = WalSession::new(conn2); - let conn2_sync_watermark = transfer_physical_changes( - &coro, - &db1, - session, - conn1_match_watermark, - conn1_sync_watermark, - conn2_match_watermark, - ) - .await?; - - // db2 WAL frames: [B1 B2] [B3] [B4] [B4^-1] [A4] (sync_watermark) [A5] - assert_eq!(conn2_sync_watermark, 6); - - let conn2 = db2.connect(&coro).await.unwrap(); - let mut rows = Vec::new(); - let mut stmt = conn2.prepare("SELECT x, y FROM t").unwrap(); - while let Some(row) = run_stmt_once(&coro, &mut stmt).await.unwrap() { - rows.push(row.get_values().cloned().collect::>()); + let mut gen = genawaiter::sync::Gen::new({ + |coro| async move { + let coro: Coro<()> = coro.into(); + let mut bytes = BytesMut::new(); + let mut count = 0; + while let Some(_) = + wait_proto_message::<(), PageData>(&coro, &completion, &mut bytes).await? 
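+                // wait_proto_message splits consumed messages off the front of `bytes`,
+                // so its capacity should stay bounded by roughly one 16 KiB page plus
+                // slack instead of growing with the whole stream (checked below)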
+ { + assert!(bytes.capacity() <= 16 * 1024 + 1024); + count += 1; + } + assert_eq!(count, 1024); + Result::Ok(()) } - assert_eq!( - rows, - vec![ - vec![Value::Integer(1), Value::Integer(2)], - vec![Value::Integer(3), Value::Integer(4)], - vec![Value::Integer(5), Value::Integer(6)], - ] - ); - - conn2.execute("INSERT INTO t VALUES (7, 8)")?; - let mut rows = Vec::new(); - let mut stmt = conn2.prepare("SELECT x, y FROM t").unwrap(); - while let Some(row) = run_stmt_once(&coro, &mut stmt).await.unwrap() { - rows.push(row.get_values().cloned().collect::>()); - } - assert_eq!( - rows, - vec![ - vec![Value::Integer(1), Value::Integer(2)], - vec![Value::Integer(3), Value::Integer(4)], - vec![Value::Integer(5), Value::Integer(6)], - vec![Value::Integer(7), Value::Integer(8)], - ] - ); - - Result::Ok(()) }); loop { match gen.resume_with(Ok(())) { - genawaiter::GeneratorState::Yielded(..) => io.run_once().unwrap(), - genawaiter::GeneratorState::Complete(result) => { - result.unwrap(); - break; - } + genawaiter::GeneratorState::Yielded(..) => {} + genawaiter::GeneratorState::Complete(result) => break result.unwrap(), } } } diff --git a/sync/engine/src/database_tape.rs b/sync/engine/src/database_tape.rs index cf62d0b53..f66d97651 100644 --- a/sync/engine/src/database_tape.rs +++ b/sync/engine/src/database_tape.rs @@ -3,15 +3,15 @@ use std::{ sync::Arc, }; -use turso_core::{types::WalFrameInfo, StepResult}; +use turso_core::{types::WalFrameInfo, LimboError, StepResult}; use crate::{ database_replay_generator::{DatabaseReplayGenerator, ReplayInfo}, database_sync_operations::WAL_FRAME_HEADER, errors::Error, types::{ - Coro, DatabaseChange, DatabaseTapeOperation, DatabaseTapeRowChange, - DatabaseTapeRowChangeType, ProtocolCommand, + Coro, DatabaseChange, DatabaseRowMutation, DatabaseRowStatement, DatabaseTapeOperation, + DatabaseTapeRowChange, DatabaseTapeRowChangeType, ProtocolCommand, }, wal_session::WalSession, Result, @@ -28,7 +28,7 @@ pub struct DatabaseTape { const DEFAULT_CDC_TABLE_NAME: &str = "turso_cdc"; const DEFAULT_CDC_MODE: &str = "full"; const DEFAULT_CHANGES_BATCH_SIZE: usize = 100; -const CDC_PRAGMA_NAME: &str = "unstable_capture_data_changes_conn"; +pub const CDC_PRAGMA_NAME: &str = "unstable_capture_data_changes_conn"; #[derive(Debug, Clone)] pub struct DatabaseTapeOpts { @@ -36,8 +36,8 @@ pub struct DatabaseTapeOpts { pub cdc_mode: Option, } -pub(crate) async fn run_stmt_once<'a>( - coro: &'_ Coro, +pub(crate) async fn run_stmt_once<'a, Ctx>( + coro: &'_ Coro, stmt: &'a mut turso_core::Statement, ) -> Result> { loop { @@ -61,8 +61,8 @@ pub(crate) async fn run_stmt_once<'a>( } } -pub(crate) async fn run_stmt_expect_one_row( - coro: &Coro, +pub(crate) async fn run_stmt_expect_one_row( + coro: &Coro, stmt: &mut turso_core::Statement, ) -> Result>> { let Some(row) = run_stmt_once(coro, stmt).await? else { @@ -75,15 +75,18 @@ pub(crate) async fn run_stmt_expect_one_row( Ok(Some(values)) } -pub(crate) async fn run_stmt_ignore_rows( - coro: &Coro, +pub(crate) async fn run_stmt_ignore_rows( + coro: &Coro, stmt: &mut turso_core::Statement, ) -> Result<()> { while run_stmt_once(coro, stmt).await?.is_some() {} Ok(()) } -pub(crate) async fn exec_stmt(coro: &Coro, stmt: &mut turso_core::Statement) -> Result<()> { +pub(crate) async fn exec_stmt( + coro: &Coro, + stmt: &mut turso_core::Statement, +) -> Result<()> { loop { match stmt.step()? 
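+        // turso_core statements are stepped manually: StepResult::IO means the
+        // statement is blocked on asynchronous IO, so the coroutine yields
+        // ProtocolCommand::IO and the embedding driver runs one IO iteration
+        // (e.g. io.run_once()) before the next step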
{ StepResult::IO => { @@ -128,7 +131,7 @@ impl DatabaseTape { let connection = self.inner.connect()?; Ok(connection) } - pub async fn connect(&self, coro: &Coro) -> Result> { + pub async fn connect(&self, coro: &Coro) -> Result> { let connection = self.inner.connect()?; tracing::debug!("set '{CDC_PRAGMA_NAME}' for new connection"); let mut stmt = connection.prepare(&self.pragma_query)?; @@ -142,19 +145,20 @@ impl DatabaseTape { ) -> Result { tracing::debug!("opening changes iterator with options {:?}", opts); let conn = self.inner.connect()?; - let query = opts.mode.query(&self.cdc_table, opts.batch_size); - let query_stmt = conn.prepare(&query)?; Ok(DatabaseChangesIterator { + conn, + cdc_table: self.cdc_table.clone(), first_change_id: opts.first_change_id, batch: VecDeque::with_capacity(opts.batch_size), - query_stmt, + query_stmt: None, txn_boundary_returned: false, mode: opts.mode, + batch_size: opts.batch_size, ignore_schema_changes: opts.ignore_schema_changes, }) } /// Start raw WAL edit session which can append or rollback pages directly in the current WAL - pub async fn start_wal_session(&self, coro: &Coro) -> Result { + pub async fn start_wal_session(&self, coro: &Coro) -> Result { let conn = self.connect(coro).await?; let mut wal_session = WalSession::new(conn); wal_session.begin()?; @@ -162,11 +166,11 @@ impl DatabaseTape { } /// Start replay session which can apply [DatabaseTapeOperation] from [Self::iterate_changes] - pub async fn start_replay_session( + pub async fn start_replay_session( &self, - coro: &Coro, - opts: DatabaseReplaySessionOpts, - ) -> Result { + coro: &Coro, + opts: DatabaseReplaySessionOpts, + ) -> Result> { tracing::debug!("opening replay session"); let conn = self.connect(coro).await?; conn.execute("BEGIN IMMEDIATE")?; @@ -184,12 +188,12 @@ impl DatabaseTape { pub struct DatabaseWalSession { page_size: usize, next_wal_frame_no: u64, - wal_session: WalSession, + pub wal_session: WalSession, prepared_frame: Option<(u32, Vec)>, } impl DatabaseWalSession { - pub async fn new(coro: &Coro, wal_session: WalSession) -> Result { + pub async fn new(coro: &Coro, wal_session: WalSession) -> Result { let conn = wal_session.conn(); let frames_count = conn.wal_state()?.max_frame; let mut page_size_stmt = conn.prepare("PRAGMA page_size")?; @@ -259,13 +263,15 @@ impl DatabaseWalSession { Ok(()) } - pub fn rollback_changes_after(&mut self, frame_watermark: u64) -> Result<()> { + pub fn rollback_changes_after(&mut self, frame_watermark: u64) -> Result { let conn = self.wal_session.conn(); let pages = conn.wal_changed_pages_after(frame_watermark)?; + tracing::info!("rolling back {} pages", pages.len()); + let pages_cnt = pages.len(); for page_no in pages { self.rollback_page(page_no, frame_watermark)?; } - Ok(()) + Ok(pages_cnt) } pub fn db_size(&self) -> Result { @@ -290,7 +296,7 @@ impl DatabaseWalSession { frame_info.put_to_frame_header(&mut frame); let frame_no = self.next_wal_frame_no; - tracing::trace!( + tracing::debug!( "flush prepared frame {:?} as frame_no {}", frame_info, frame_no @@ -352,17 +358,20 @@ impl Default for DatabaseChangesIteratorOpts { } pub struct DatabaseChangesIterator { - query_stmt: turso_core::Statement, + conn: Arc, + cdc_table: Arc, + query_stmt: Option, first_change_id: Option, batch: VecDeque, txn_boundary_returned: bool, mode: DatabaseChangesIteratorMode, + batch_size: usize, ignore_schema_changes: bool, } const SQLITE_SCHEMA_TABLE: &str = "sqlite_schema"; impl DatabaseChangesIterator { - pub async fn next(&mut self, coro: &Coro) -> Result> { 
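+    /// Returns the next [DatabaseTapeOperation]: row changes read from the CDC table
+    /// in `mode` order, followed by a transaction boundary marker once the captured
+    /// changes are exhausted. A minimal driving loop (sketch, assuming default
+    /// iterator options):
+    /// ```ignore
+    /// let mut changes = tape.iterate_changes(Default::default())?;
+    /// while let Some(operation) = changes.next(&coro).await? {
+    ///     session.replay(&coro, operation).await?;
+    /// }
+    /// ```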
+ pub async fn next(&mut self, coro: &Coro) -> Result> { if self.batch.is_empty() { self.refill(coro).await?; } @@ -386,15 +395,26 @@ impl DatabaseChangesIterator { return Ok(next); } } - async fn refill(&mut self, coro: &Coro) -> Result<()> { + async fn refill(&mut self, coro: &Coro) -> Result<()> { + if self.query_stmt.is_none() { + let query = self.mode.query(&self.cdc_table, self.batch_size); + let stmt = match self.conn.prepare(&query) { + Ok(stmt) => stmt, + Err(LimboError::ParseError(err)) if err.contains("no such table") => return Ok(()), + Err(err) => return Err(err.into()), + }; + self.query_stmt = Some(stmt); + } + let query_stmt = self.query_stmt.as_mut().unwrap(); + let change_id_filter = self.first_change_id.unwrap_or(self.mode.first_id()); - self.query_stmt.reset(); - self.query_stmt.bind_at( + query_stmt.reset(); + query_stmt.bind_at( 1.try_into().unwrap(), turso_core::Value::Integer(change_id_filter), ); - while let Some(row) = run_stmt_once(coro, &mut self.query_stmt).await? { + while let Some(row) = run_stmt_once(coro, query_stmt).await? { let database_change: DatabaseChange = row.try_into()?; let tape_change = match self.mode { DatabaseChangesIteratorMode::Apply => database_change.into_apply()?, @@ -410,43 +430,59 @@ impl DatabaseChangesIterator { } } -#[derive(Debug, Clone)] -pub struct DatabaseReplaySessionOpts { +#[derive(Clone)] +pub struct DatabaseReplaySessionOpts { pub use_implicit_rowid: bool, + pub transform: Option< + Arc Result> + 'static>, + >, } -struct CachedStmt { +impl std::fmt::Debug for DatabaseReplaySessionOpts { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DatabaseReplaySessionOpts") + .field("use_implicit_rowid", &self.use_implicit_rowid) + .field("transform_mutation.is_some()", &self.transform.is_some()) + .finish() + } +} + +pub(crate) struct CachedStmt { stmt: turso_core::Statement, info: ReplayInfo, } -pub struct DatabaseReplaySession { - conn: Arc, - cached_delete_stmt: HashMap, - cached_insert_stmt: HashMap<(String, usize), CachedStmt>, - cached_update_stmt: HashMap<(String, Vec), CachedStmt>, - in_txn: bool, - generator: DatabaseReplayGenerator, +pub struct DatabaseReplaySession { + pub(crate) conn: Arc, + pub(crate) cached_delete_stmt: HashMap, + pub(crate) cached_insert_stmt: HashMap<(String, usize), CachedStmt>, + pub(crate) cached_update_stmt: HashMap<(String, Vec), CachedStmt>, + pub(crate) in_txn: bool, + pub(crate) generator: DatabaseReplayGenerator, } -async fn replay_stmt( - coro: &Coro, - cached: &mut CachedStmt, +async fn replay_stmt( + coro: &Coro, + stmt: &mut turso_core::Statement, values: Vec, ) -> Result<()> { - cached.stmt.reset(); + stmt.reset(); for (i, value) in values.into_iter().enumerate() { - cached.stmt.bind_at((i + 1).try_into().unwrap(), value); + stmt.bind_at((i + 1).try_into().unwrap(), value); } - exec_stmt(coro, &mut cached.stmt).await?; + exec_stmt(coro, stmt).await?; Ok(()) } -impl DatabaseReplaySession { +impl DatabaseReplaySession { pub fn conn(&self) -> Arc { self.conn.clone() } - pub async fn replay(&mut self, coro: &Coro, operation: DatabaseTapeOperation) -> Result<()> { + pub async fn replay( + &mut self, + coro: &Coro, + operation: DatabaseTapeOperation, + ) -> Result<()> { match operation { DatabaseTapeOperation::Commit => { tracing::debug!("replay: commit replayed changes after transaction boundary"); @@ -466,10 +502,23 @@ impl DatabaseReplaySession { if table == SQLITE_SCHEMA_TABLE { let replay_info = self.generator.replay_info(coro, &change).await?; - for 
replay in &replay_info { - self.conn.execute(replay.query.as_str())?; - } + self.conn.execute(replay_info.query.as_str())?; } else { + if let Some(transform) = &self.generator.opts.transform { + let replay_info = self.generator.replay_info(coro, &change).await?; + let mutation = self.generator.create_mutation(&replay_info, &change)?; + let statement = transform(&coro.ctx.borrow(), mutation)?; + if let Some(statement) = statement { + tracing::info!( + "replay: use mutation from custom transformer: sql={}, values={:?}", + statement.sql, + statement.values + ); + let mut stmt = self.conn.prepare(&statement.sql)?; + replay_stmt(coro, &mut stmt, statement.values).await?; + return Ok(()); + } + } match change.change { DatabaseTapeRowChangeType::Delete { before } => { let key = self.populate_delete_stmt(coro, table).await?; @@ -486,7 +535,7 @@ impl DatabaseReplaySession { before, None, ); - replay_stmt(coro, cached, values).await?; + replay_stmt(coro, &mut cached.stmt, values).await?; } DatabaseTapeRowChangeType::Insert { after } => { let key = self.populate_insert_stmt(coro, table, after.len()).await?; @@ -503,7 +552,7 @@ impl DatabaseReplaySession { after, None, ); - replay_stmt(coro, cached, values).await?; + replay_stmt(coro, &mut cached.stmt, values).await?; } DatabaseTapeRowChangeType::Update { after, @@ -533,7 +582,7 @@ impl DatabaseReplaySession { after, Some(updates), ); - replay_stmt(coro, cached, values).await?; + replay_stmt(coro, &mut cached.stmt, values).await?; } DatabaseTapeRowChangeType::Update { before, @@ -554,7 +603,7 @@ impl DatabaseReplaySession { before, None, ); - replay_stmt(coro, cached, values).await?; + replay_stmt(coro, &mut cached.stmt, values).await?; let key = self.populate_insert_stmt(coro, table, after.len()).await?; tracing::trace!( @@ -570,7 +619,7 @@ impl DatabaseReplaySession { after, None, ); - replay_stmt(coro, cached, values).await?; + replay_stmt(coro, &mut cached.stmt, values).await?; } } } @@ -578,7 +627,11 @@ impl DatabaseReplaySession { } Ok(()) } - async fn populate_delete_stmt<'a>(&mut self, coro: &Coro, table: &'a str) -> Result<&'a str> { + async fn populate_delete_stmt<'a>( + &mut self, + coro: &Coro, + table: &'a str, + ) -> Result<&'a str> { if self.cached_delete_stmt.contains_key(table) { return Ok(table); } @@ -591,7 +644,7 @@ impl DatabaseReplaySession { } async fn populate_insert_stmt( &mut self, - coro: &Coro, + coro: &Coro, table: &str, columns: usize, ) -> Result<(String, usize)> { @@ -612,7 +665,7 @@ impl DatabaseReplaySession { } async fn populate_update_stmt( &mut self, - coro: &Coro, + coro: &Coro, table: &str, columns: &[bool], ) -> Result<(String, Vec)> { @@ -639,7 +692,7 @@ mod tests { database_tape::{ run_stmt_once, DatabaseChangesIteratorOpts, DatabaseReplaySessionOpts, DatabaseTape, }, - types::{DatabaseTapeOperation, DatabaseTapeRowChange, DatabaseTapeRowChangeType}, + types::{Coro, DatabaseTapeOperation, DatabaseTapeRowChange, DatabaseTapeRowChangeType}, }; #[test] @@ -653,6 +706,7 @@ mod tests { let mut gen = genawaiter::sync::Gen::new({ let db1 = db1.clone(); |coro| async move { + let coro: Coro<()> = coro.into(); let conn = db1.connect(&coro).await.unwrap(); let mut stmt = conn.prepare("SELECT * FROM turso_cdc").unwrap(); let mut rows = Vec::new(); @@ -683,6 +737,7 @@ mod tests { let mut gen = genawaiter::sync::Gen::new({ let db1 = db1.clone(); |coro| async move { + let coro: Coro<()> = coro.into(); let conn = db1.connect(&coro).await.unwrap(); conn.execute("CREATE TABLE t(x)").unwrap(); conn.execute("INSERT INTO t VALUES 
(1), (2), (3)").unwrap(); @@ -754,6 +809,7 @@ mod tests { let db1 = db1.clone(); let db2 = db2.clone(); |coro| async move { + let coro: Coro<()> = coro.into(); let conn1 = db1.connect(&coro).await.unwrap(); conn1.execute("CREATE TABLE t(x)").unwrap(); conn1 @@ -768,6 +824,7 @@ mod tests { { let opts = DatabaseReplaySessionOpts { use_implicit_rowid: true, + transform: None, }; let mut session = db2.start_replay_session(&coro, opts).await.unwrap(); let opts = Default::default(); @@ -832,6 +889,7 @@ mod tests { let db1 = db1.clone(); let db2 = db2.clone(); |coro| async move { + let coro: Coro<()> = coro.into(); let conn1 = db1.connect(&coro).await.unwrap(); conn1.execute("CREATE TABLE t(x)").unwrap(); conn1 @@ -846,6 +904,7 @@ mod tests { { let opts = DatabaseReplaySessionOpts { use_implicit_rowid: false, + transform: None, }; let mut session = db2.start_replay_session(&coro, opts).await.unwrap(); let opts = Default::default(); @@ -904,6 +963,7 @@ mod tests { let db1 = db1.clone(); let db2 = db2.clone(); |coro| async move { + let coro: Coro<()> = coro.into(); let conn1 = db1.connect(&coro).await.unwrap(); conn1.execute("CREATE TABLE t(x TEXT PRIMARY KEY)").unwrap(); conn1.execute("INSERT INTO t(x) VALUES ('a')").unwrap(); @@ -915,6 +975,7 @@ mod tests { { let opts = DatabaseReplaySessionOpts { use_implicit_rowid: false, + transform: None, }; let mut session = db2.start_replay_session(&coro, opts).await.unwrap(); let opts = Default::default(); @@ -969,6 +1030,7 @@ mod tests { let mut gen = genawaiter::sync::Gen::new({ |coro| async move { + let coro: Coro<()> = coro.into(); let conn1 = db1.connect(&coro).await.unwrap(); conn1 .execute("CREATE TABLE t(x TEXT PRIMARY KEY, y)") @@ -988,6 +1050,7 @@ mod tests { { let opts = DatabaseReplaySessionOpts { use_implicit_rowid: false, + transform: None, }; let mut session = db3.start_replay_session(&coro, opts).await.unwrap(); @@ -1094,6 +1157,7 @@ mod tests { let mut gen = genawaiter::sync::Gen::new({ |coro| async move { + let coro: Coro<()> = coro.into(); let conn1 = db1.connect(&coro).await.unwrap(); conn1 .execute("CREATE TABLE t(x TEXT PRIMARY KEY, y)") @@ -1104,6 +1168,7 @@ mod tests { { let opts = DatabaseReplaySessionOpts { use_implicit_rowid: false, + transform: None, }; let mut session = db2.start_replay_session(&coro, opts).await.unwrap(); @@ -1177,6 +1242,7 @@ mod tests { let mut gen = genawaiter::sync::Gen::new({ |coro| async move { + let coro: Coro<()> = coro.into(); let conn1 = db1.connect(&coro).await.unwrap(); conn1 .execute("CREATE TABLE t(x TEXT PRIMARY KEY, y)") @@ -1188,6 +1254,7 @@ mod tests { { let opts = DatabaseReplaySessionOpts { use_implicit_rowid: false, + transform: None, }; let mut session = db2.start_replay_session(&coro, opts).await.unwrap(); @@ -1255,6 +1322,7 @@ mod tests { let mut gen = genawaiter::sync::Gen::new({ |coro| async move { + let coro: Coro<()> = coro.into(); let conn1 = db1.connect(&coro).await.unwrap(); conn1 .execute("CREATE TABLE t(x TEXT PRIMARY KEY, y, z)") @@ -1283,6 +1351,7 @@ mod tests { { let opts = DatabaseReplaySessionOpts { use_implicit_rowid: false, + transform: None, }; let mut session = db3.start_replay_session(&coro, opts).await.unwrap(); diff --git a/sync/engine/src/errors.rs b/sync/engine/src/errors.rs index abd40a666..d71b4b947 100644 --- a/sync/engine/src/errors.rs +++ b/sync/engine/src/errors.rs @@ -10,6 +10,8 @@ pub enum Error { DatabaseSyncEngineError(String), #[error("database sync engine conflict: {0}")] DatabaseSyncEngineConflict(String), + #[error("database sync engine IO error: 
{0}")] + IoError(#[from] std::io::Error), } #[cfg(test)] diff --git a/sync/engine/src/io_operations.rs b/sync/engine/src/io_operations.rs index 517ad2601..777687e18 100644 --- a/sync/engine/src/io_operations.rs +++ b/sync/engine/src/io_operations.rs @@ -12,9 +12,9 @@ pub trait IoOperations { fn open_tape(&self, path: &str, capture: bool) -> Result; fn try_open(&self, path: &str) -> Result>>; fn create(&self, path: &str) -> Result>; - fn truncate( + fn truncate( &self, - coro: &Coro, + coro: &Coro, file: Arc, len: usize, ) -> impl std::future::Future>; @@ -47,9 +47,9 @@ impl IoOperations for Arc { } } - async fn truncate( + async fn truncate( &self, - coro: &Coro, + coro: &Coro, file: Arc, len: usize, ) -> Result<()> { diff --git a/sync/engine/src/protocol_io.rs b/sync/engine/src/protocol_io.rs index 77577d6c5..f381b4c1c 100644 --- a/sync/engine/src/protocol_io.rs +++ b/sync/engine/src/protocol_io.rs @@ -15,6 +15,11 @@ pub trait ProtocolIO { type DataCompletion: DataCompletion; fn full_read(&self, path: &str) -> Result; fn full_write(&self, path: &str, content: Vec) -> Result; - fn http(&self, method: &str, path: &str, body: Option>) - -> Result; + fn http( + &self, + method: &str, + path: &str, + body: Option>, + headers: &[(&str, &str)], + ) -> Result; } diff --git a/sync/engine/src/server_proto.rs b/sync/engine/src/server_proto.rs index 0289afe16..19e72082c 100644 --- a/sync/engine/src/server_proto.rs +++ b/sync/engine/src/server_proto.rs @@ -3,6 +3,64 @@ use std::collections::VecDeque; use bytes::Bytes; use serde::{Deserialize, Serialize}; +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(rename_all = "snake_case")] +#[derive(prost::Enumeration)] +#[repr(i32)] +pub enum PageUpdatesEncodingReq { + Raw = 0, + Zstd = 1, +} + +#[derive(prost::Message)] +pub struct PullUpdatesReqProtoBody { + #[prost(enumeration = "PageUpdatesEncodingReq", tag = "1")] + pub encoding: i32, + #[prost(string, tag = "2")] + pub server_revision: String, + #[prost(string, tag = "3")] + pub client_revision: String, + #[prost(uint32, tag = "4")] + pub long_poll_timeout_ms: u32, + #[prost(bytes, tag = "5")] + pub server_pages: Bytes, + #[prost(bytes, tag = "6")] + pub client_pages: Bytes, +} + +#[derive(prost::Message, Serialize, Deserialize, Clone, Eq, PartialEq)] +pub struct PageData { + #[prost(uint64, tag = "1")] + pub page_id: u64, + + #[serde(with = "bytes_as_base64_pad")] + #[prost(bytes, tag = "2")] + pub encoded_page: Bytes, +} + +#[derive(prost::Message)] +pub struct PageSetRawEncodingProto {} + +#[derive(prost::Message)] +pub struct PageSetZstdEncodingProto { + #[prost(int32, tag = "1")] + pub level: i32, + #[prost(uint32, repeated, tag = "2")] + pub pages_dict: Vec, +} + +#[derive(prost::Message)] +pub struct PullUpdatesRespProtoBody { + #[prost(string, tag = "1")] + pub server_revision: String, + #[prost(uint64, tag = "2")] + pub db_size: u64, + #[prost(optional, message, tag = "3")] + pub raw_encoding: Option, + #[prost(optional, message, tag = "4")] + pub zstd_encoding: Option, +} + #[derive(Serialize, Deserialize, Debug)] pub struct PipelineReqBody { pub baton: Option, @@ -22,8 +80,6 @@ pub enum StreamRequest { #[serde(skip_deserializing)] #[default] None, - /// See [`CloseStreamReq`] - Close(CloseStreamReq), /// See [`ExecuteStreamReq`] Execute(ExecuteStreamReq), } @@ -33,15 +89,53 @@ pub enum StreamRequest { pub enum StreamResult { #[default] None, - Ok, + Ok { + response: StreamResponse, + }, Error { error: Error, }, } -#[derive(Serialize, Deserialize, Debug)] -/// A request to 
close the current stream.
-pub struct CloseStreamReq {}
+#[derive(Serialize, Deserialize, Debug, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum StreamResponse {
+ Execute(ExecuteStreamResp),
+}
+
+#[derive(Serialize, Deserialize, Debug, PartialEq)]
+/// A response to an [`ExecuteStreamReq`].
+pub struct ExecuteStreamResp {
+ pub result: StmtResult,
+}
+#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Default)]
+pub struct StmtResult {
+ pub cols: Vec<Col>,
+ pub rows: Vec<Row>,
+ pub affected_row_count: u64,
+ #[serde(with = "option_i64_as_str")]
+ pub last_insert_rowid: Option<i64>,
+ #[serde(default, with = "option_u64_as_str")]
+ pub replication_index: Option<u64>,
+ #[serde(default)]
+ pub rows_read: u64,
+ #[serde(default)]
+ pub rows_written: u64,
+ #[serde(default)]
+ pub query_duration_ms: f64,
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
+pub struct Col {
+ pub name: Option<String>,
+ pub decltype: Option<String>,
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
+#[serde(transparent)]
+pub struct Row {
+ pub values: Vec,
+}
#[derive(Serialize, Deserialize, Debug)]
/// A request to execute a single SQL statement.
@@ -229,3 +323,80 @@ pub(crate) mod bytes_as_base64 {
Ok(Bytes::from(bytes))
}
}
+
+mod option_i64_as_str {
+ use serde::de::{Error, Visitor};
+ use serde::{ser, Deserializer, Serialize as _};
+
+ pub fn serialize<S: ser::Serializer>(value: &Option<i64>, ser: S) -> Result<S::Ok, S::Error> {
+ value.map(|v| v.to_string()).serialize(ser)
+ }
+
+ pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<i64>, D::Error> {
+ struct V;
+
+ impl<'de> Visitor<'de> for V {
+ type Value = Option<i64>;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(formatter, "a string representing a signed integer, or null")
+ }
+
+ fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ deserializer.deserialize_any(V)
+ }
+
+ fn visit_none<E>(self) -> Result<Self::Value, E>
+ where
+ E: Error,
+ {
+ Ok(None)
+ }
+
+ fn visit_unit<E>(self) -> Result<Self::Value, E>
+ where
+ E: Error,
+ {
+ Ok(None)
+ }
+
+ fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
+ where
+ E: Error,
+ {
+ Ok(Some(v))
+ }
+
+ fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
+ where
+ E: Error,
+ {
+ v.parse().map_err(E::custom).map(Some)
+ }
+ }
+
+ d.deserialize_option(V)
+ }
+}
+
+pub(crate) mod bytes_as_base64_pad {
+ use base64::{engine::general_purpose::STANDARD, Engine as _};
+ use bytes::Bytes;
+ use serde::{de, ser};
+ use serde::{de::Error as _, Serialize as _};
+
+ pub fn serialize<S: ser::Serializer>(value: &Bytes, ser: S) -> Result<S::Ok, S::Error> {
+ STANDARD.encode(value).serialize(ser)
+ }
+
+ pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result<Bytes, D::Error> {
+ let text = <&'de str as de::Deserialize>::deserialize(de)?;
+ let bytes = STANDARD.decode(text).map_err(|_| {
+ D::Error::invalid_value(de::Unexpected::Str(text), &"binary data encoded as base64")
+ })?;
+ Ok(Bytes::from(bytes))
+ }
+}
diff --git a/sync/engine/src/types.rs b/sync/engine/src/types.rs
index 5b32462c0..20964b761 100644
--- a/sync/engine/src/types.rs
+++ b/sync/engine/src/types.rs
@@ -1,8 +1,36 @@
+use std::{cell::RefCell, collections::HashMap};
+
use serde::{Deserialize, Serialize};
use crate::{errors::Error, Result};
-pub type Coro = genawaiter::sync::Co<ProtocolCommand, Result<()>>;
+pub struct Coro<Ctx> {
+ pub ctx: RefCell<Ctx>,
+ gen: genawaiter::sync::Co<ProtocolCommand, Result<Ctx>>,
+}
+
+impl<Ctx> Coro<Ctx> {
+ pub fn new(ctx: Ctx, gen: genawaiter::sync::Co<ProtocolCommand, Result<Ctx>>) -> Self {
+ Self {
+ ctx: RefCell::new(ctx),
+ gen,
+ }
+ }
+ pub async fn yield_(&self, value: ProtocolCommand) -> Result<()> {
+ let ctx = self.gen.yield_(value).await?;
+ 
self.ctx.replace(ctx);
+ Ok(())
+ }
+}
+
+impl From<genawaiter::sync::Co<ProtocolCommand, Result<()>>> for Coro<()> {
+ fn from(value: genawaiter::sync::Co<ProtocolCommand, Result<()>>) -> Self {
+ Self {
+ gen: value,
+ ctx: RefCell::new(()),
+ }
+ }
+}
#[derive(Debug, Deserialize, Serialize)]
pub struct DbSyncInfo {
@@ -17,6 +45,17 @@ pub struct DbSyncStatus {
pub max_frame_no: u64,
}
+#[derive(Debug)]
+pub struct DbChangesStatus {
+ pub revision: DatabasePullRevision,
+ pub file_path: String,
+}
+
+pub struct SyncEngineStats {
+ pub cdc_operations: i64,
+ pub wal_size: i64,
+}
+
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DatabaseChangeType {
Delete,
@@ -29,12 +68,30 @@ pub struct DatabaseMetadata {
/// Unique identifier of the client - generated on sync startup
pub client_unique_id: String,
/// Latest generation from remote which was pulled locally to the Synced DB
- pub synced_generation: u64,
- /// Latest frame number from remote which was pulled locally to the Synced DB
- pub synced_frame_no: Option<u64>,
+ pub synced_revision: Option<DatabasePullRevision>,
/// pair of frame_no for Draft and Synced DB such that content of the database file up to these frames is identical
- pub draft_wal_match_watermark: u64,
- pub synced_wal_match_watermark: u64,
+ pub revert_since_wal_salt: Option<Vec<u32>>,
+ pub revert_since_wal_watermark: u64,
+ pub last_pushed_pull_gen_hint: i64,
+ pub last_pushed_change_id_hint: i64,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum DatabasePullRevision {
+ Legacy {
+ generation: u64,
+ synced_frame_no: Option<u64>,
+ },
+ V1 {
+ revision: String,
+ },
+}
+
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
+pub enum DatabaseSyncEngineProtocolVersion {
+ Legacy,
+ V1,
}
impl DatabaseMetadata {
@@ -199,6 +256,21 @@ impl TryFrom<&turso_core::Row> for DatabaseChange {
}
}
+pub struct DatabaseRowMutation {
+ pub change_time: u64,
+ pub table_name: String,
+ pub id: i64,
+ pub change_type: DatabaseChangeType,
+ pub before: Option<HashMap<String, turso_core::Value>>,
+ pub after: Option<HashMap<String, turso_core::Value>>,
+ pub updates: Option<HashMap<String, turso_core::Value>>,
+}
+
+pub struct DatabaseRowStatement {
+ pub sql: String,
+ pub values: Vec<turso_core::Value>,
+}
+
pub enum DatabaseTapeRowChangeType {
Delete {
before: Vec<turso_core::Value>,
diff --git a/sync/engine/src/wal_session.rs b/sync/engine/src/wal_session.rs
index 489634e01..379559443 100644
--- a/sync/engine/src/wal_session.rs
+++ b/sync/engine/src/wal_session.rs
@@ -38,9 +38,9 @@ impl WalSession {
let info = self.conn.wal_get_frame(frame_no, frame)?;
Ok(info)
}
- pub fn end(&mut self) -> Result<()> {
+ pub fn end(&mut self, force_commit: bool) -> Result<()> {
assert!(self.in_txn);
- self.conn.wal_insert_end(false)?;
+ self.conn.wal_insert_end(force_commit)?;
self.in_txn = false;
Ok(())
}
@@ -53,7 +53,7 @@ impl Drop for WalSession {
fn drop(&mut self) {
if self.in_txn {
let _ = self
- .end()
+ .end(false)
.inspect_err(|e| tracing::error!("failed to close WAL session: {}", e));
}
}
diff --git a/sync/javascript/index.d.ts b/sync/javascript/index.d.ts
index fe82c3575..62d3b18ae 100644
--- a/sync/javascript/index.d.ts
+++ b/sync/javascript/index.d.ts
@@ -67,6 +67,14 @@ export declare class Database {
* `Ok(())` if the database is closed successfully.
*/
close(): void
+ /**
+ * Sets the default safe integers mode for all statements from this database.
+ *
+ * # Arguments
+ *
+ * * `toggle` - Whether to use safe integers by default.
+ */
+ defaultSafeIntegers(toggle?: boolean | undefined | null): void
/** Runs the I/O loop synchronously. */
ioLoopSync(): void
/** Runs the I/O loop asynchronously, returning a Promise. 
*/ @@ -107,11 +115,22 @@ export declare class Statement { raw(raw?: boolean | undefined | null): void /** Sets the presentation mode to pluck. */ pluck(pluck?: boolean | undefined | null): void + /** + * Sets safe integers mode for this statement. + * + * # Arguments + * + * * `toggle` - Whether to use safe integers. + */ + safeIntegers(toggle?: boolean | undefined | null): void + /** Get column information for the statement */ + columns(): unknown[] /** Finalizes the statement. */ finalize(): void } export declare class GeneratorHolder { resume(error?: string | undefined | null): number + take(): GeneratorResponse | null } export declare class JsDataCompletion { @@ -143,16 +162,42 @@ export declare class SyncEngine { protocolIo(): JsProtocolRequestData | null sync(): GeneratorHolder push(): GeneratorHolder + stats(): GeneratorHolder pull(): GeneratorHolder + checkpoint(): GeneratorHolder open(): Database } +export declare const enum DatabaseChangeTypeJs { + Insert = 0, + Update = 1, + Delete = 2 +} + export interface DatabaseOpts { path: string } +export interface DatabaseRowMutationJs { + changeTime: number + tableName: string + id: number + changeType: DatabaseChangeTypeJs + before?: Record + after?: Record + updates?: Record +} + +export interface DatabaseRowStatementJs { + sql: string + values: Array +} + +export type GeneratorResponse = + | { type: 'SyncEngineStats', operations: number, wal: number } + export type JsProtocolRequest = - | { type: 'Http', method: string, path: string, body?: Buffer } + | { type: 'Http', method: string, path: string, body?: Array, headers: Array<[string, string]> } | { type: 'FullRead', path: string } | { type: 'FullWrite', path: string, content: Array } @@ -160,5 +205,13 @@ export interface SyncEngineOpts { path: string clientName?: string walPullBatchSize?: number - enableTracing?: boolean + enableTracing?: string + tablesIgnore?: Array + transform?: (arg: DatabaseRowMutationJs) => DatabaseRowStatementJs | null + protocolVersion?: SyncEngineProtocolVersion +} + +export declare const enum SyncEngineProtocolVersion { + Legacy = 0, + V1 = 1 } diff --git a/sync/javascript/index.js b/sync/javascript/index.js index bec3bf26f..24f60146e 100644 --- a/sync/javascript/index.js +++ b/sync/javascript/index.js @@ -79,12 +79,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-android-arm64') - const bindingPackageVersion = require('@tursodatabase/sync-android-arm64/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-android-arm64') } catch (e) { loadErrors.push(e) } @@ -95,12 +90,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-android-arm-eabi') - const bindingPackageVersion = require('@tursodatabase/sync-android-arm-eabi/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-android-arm-eabi') } catch (e) { loadErrors.push(e) } @@ -115,12 +105,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-win32-x64-msvc') - const bindingPackageVersion = require('@tursodatabase/sync-win32-x64-msvc/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-win32-x64-msvc') } catch (e) { loadErrors.push(e) } @@ -131,12 +116,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-win32-ia32-msvc') - const bindingPackageVersion = require('@tursodatabase/sync-win32-ia32-msvc/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-win32-ia32-msvc') } catch (e) { loadErrors.push(e) } @@ -147,12 +127,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-win32-arm64-msvc') - const bindingPackageVersion = require('@tursodatabase/sync-win32-arm64-msvc/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-win32-arm64-msvc') } catch (e) { loadErrors.push(e) } @@ -166,12 +141,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-darwin-universal') - const bindingPackageVersion = require('@tursodatabase/sync-darwin-universal/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-darwin-universal') } catch (e) { loadErrors.push(e) } @@ -182,12 +152,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-darwin-x64') - const bindingPackageVersion = require('@tursodatabase/sync-darwin-x64/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-darwin-x64') } catch (e) { loadErrors.push(e) } @@ -198,12 +163,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-darwin-arm64') - const bindingPackageVersion = require('@tursodatabase/sync-darwin-arm64/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-darwin-arm64') } catch (e) { loadErrors.push(e) } @@ -218,12 +178,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-freebsd-x64') - const bindingPackageVersion = require('@tursodatabase/sync-freebsd-x64/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-freebsd-x64') } catch (e) { loadErrors.push(e) } @@ -234,12 +189,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-freebsd-arm64') - const bindingPackageVersion = require('@tursodatabase/sync-freebsd-arm64/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-freebsd-arm64') } catch (e) { loadErrors.push(e) } @@ -255,12 +205,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-x64-musl') - const bindingPackageVersion = require('@tursodatabase/sync-linux-x64-musl/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-x64-musl') } catch (e) { loadErrors.push(e) } @@ -271,12 +216,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-x64-gnu') - const bindingPackageVersion = require('@tursodatabase/sync-linux-x64-gnu/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-x64-gnu') } catch (e) { loadErrors.push(e) } @@ -289,12 +229,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-arm64-musl') - const bindingPackageVersion = require('@tursodatabase/sync-linux-arm64-musl/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-arm64-musl') } catch (e) { loadErrors.push(e) } @@ -305,12 +240,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-arm64-gnu') - const bindingPackageVersion = require('@tursodatabase/sync-linux-arm64-gnu/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-arm64-gnu') } catch (e) { loadErrors.push(e) } @@ -323,12 +253,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-arm-musleabihf') - const bindingPackageVersion = require('@tursodatabase/sync-linux-arm-musleabihf/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-arm-musleabihf') } catch (e) { loadErrors.push(e) } @@ -339,12 +264,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-arm-gnueabihf') - const bindingPackageVersion = require('@tursodatabase/sync-linux-arm-gnueabihf/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-arm-gnueabihf') } catch (e) { loadErrors.push(e) } @@ -357,12 +277,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-riscv64-musl') - const bindingPackageVersion = require('@tursodatabase/sync-linux-riscv64-musl/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-riscv64-musl') } catch (e) { loadErrors.push(e) } @@ -373,12 +288,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-riscv64-gnu') - const bindingPackageVersion = require('@tursodatabase/sync-linux-riscv64-gnu/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-riscv64-gnu') } catch (e) { loadErrors.push(e) } @@ -390,12 +300,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-ppc64-gnu') - const bindingPackageVersion = require('@tursodatabase/sync-linux-ppc64-gnu/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-ppc64-gnu') } catch (e) { loadErrors.push(e) } @@ -406,12 +311,7 @@ function requireNative() { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-linux-s390x-gnu') - const bindingPackageVersion = require('@tursodatabase/sync-linux-s390x-gnu/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-s390x-gnu') } catch (e) { loadErrors.push(e) } @@ -421,49 +321,34 @@ function requireNative() { } else if (process.platform === 'openharmony') { if (process.arch === 'arm64') { try { - return require('./turso-sync-js.openharmony-arm64.node') + return require('./turso-sync-js.linux-arm64-ohos.node') } catch (e) { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-openharmony-arm64') - const bindingPackageVersion = require('@tursodatabase/sync-openharmony-arm64/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-arm64-ohos') } catch (e) { loadErrors.push(e) } } else if (process.arch === 'x64') { try { - return require('./turso-sync-js.openharmony-x64.node') + return require('./turso-sync-js.linux-x64-ohos.node') } catch (e) { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-openharmony-x64') - const bindingPackageVersion = require('@tursodatabase/sync-openharmony-x64/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-x64-ohos') } catch (e) { loadErrors.push(e) } } else if (process.arch === 'arm') { try { - return require('./turso-sync-js.openharmony-arm.node') + return require('./turso-sync-js.linux-arm-ohos.node') } catch (e) { loadErrors.push(e) } try { - const binding = require('@tursodatabase/sync-openharmony-arm') - const bindingPackageVersion = require('@tursodatabase/sync-openharmony-arm/package.json').version - if (bindingPackageVersion !== '0.1.4-pre.5') { - throw new Error(`Native binding package version mismatch, expected 0.1.4-pre.5 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) - } - return binding + return require('@tursodatabase/sync-linux-arm-ohos') } catch (e) { loadErrors.push(e) } @@ -508,7 +393,7 @@ if (!nativeBinding) { throw new Error(`Failed to load native binding`) } -const { Database, Statement, GeneratorHolder, JsDataCompletion, JsDataPollResult, JsProtocolIo, JsProtocolRequestData, SyncEngine } = nativeBinding +const { Database, Statement, GeneratorHolder, JsDataCompletion, JsDataPollResult, JsProtocolIo, JsProtocolRequestData, SyncEngine, DatabaseChangeTypeJs, SyncEngineProtocolVersion } = nativeBinding export { Database } export { Statement } export { GeneratorHolder } @@ -517,3 +402,5 @@ export { JsDataPollResult } export { JsProtocolIo } export { JsProtocolRequestData } export { SyncEngine } +export { DatabaseChangeTypeJs } +export { SyncEngineProtocolVersion } diff --git a/sync/javascript/package.json b/sync/javascript/package.json index e7220025a..f43d81297 100644 --- a/sync/javascript/package.json +++ b/sync/javascript/package.json @@ -59,4 +59,4 @@ "dependencies": { "@tursodatabase/database": "~0.1.4-pre.5" } -} \ No newline at end of file +} diff --git a/sync/javascript/src/generator.rs b/sync/javascript/src/generator.rs index b7fc0e487..c2917eb94 100644 --- a/sync/javascript/src/generator.rs +++ b/sync/javascript/src/generator.rs @@ -1,5 +1,9 @@ +use napi::Env; use napi_derive::napi; -use std::{future::Future, sync::Mutex}; +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; use turso_sync_engine::types::ProtocolCommand; @@ -7,18 +11,18 @@ pub const GENERATOR_RESUME_IO: u32 = 0; pub const GENERATOR_RESUME_DONE: u32 = 1; pub trait Generator { - fn resume(&mut self, result: Option) -> napi::Result; + fn resume(&mut self, env: Env, result: Option) -> napi::Result; } impl>> Generator - for genawaiter::sync::Gen, F> + for genawaiter::sync::Gen, F> { - fn resume(&mut self, error: Option) -> napi::Result { + fn resume(&mut self, env: Env, error: Option) -> napi::Result { let result = match error { Some(err) => Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError( format!("JsProtocolIo error: {err}"), )), - None => Ok(()), + None => Ok(env), }; match self.resume_with(result) { genawaiter::GeneratorState::Yielded(ProtocolCommand::IO) => Ok(GENERATOR_RESUME_IO), @@ -31,15 +35,25 @@ impl>> Generator } } +#[napi(discriminant = "type")] +pub enum GeneratorResponse { + SyncEngineStats { operations: i64, wal: i64 }, +} + #[napi] pub struct GeneratorHolder { pub(crate) inner: Box>, + pub(crate) response: Arc>>, } #[napi] impl GeneratorHolder { #[napi] - pub fn resume(&self, error: Option) -> napi::Result { - self.inner.lock().unwrap().resume(error) + pub fn resume(&self, env: Env, error: Option) -> napi::Result { + self.inner.lock().unwrap().resume(env, error) + } + #[napi] + pub fn take(&self) -> Option { + self.response.lock().unwrap().take() } } diff --git a/sync/javascript/src/js_protocol_io.rs b/sync/javascript/src/js_protocol_io.rs index fffd7d026..429c85f7f 100644 --- a/sync/javascript/src/js_protocol_io.rs +++ b/sync/javascript/src/js_protocol_io.rs @@ -15,6 +15,7 @@ pub enum JsProtocolRequest { method: String, path: String, body: Option>, + headers: Vec<(String, String)>, }, FullRead { path: String, @@ -130,11 +131,16 @@ impl ProtocolIO for JsProtocolIo { method: &str, path: &str, body: Option>, + headers: &[(&str, &str)], ) -> turso_sync_engine::Result { Ok(self.add_request(JsProtocolRequest::Http { method: method.to_string(), path: path.to_string(), body, + headers: headers + 
.iter()
+ .map(|x| (x.0.to_string(), x.1.to_string()))
+ .collect(),
}))
}
diff --git a/sync/javascript/src/lib.rs b/sync/javascript/src/lib.rs
index e83b8d9ec..3c6a2a7c9 100644
--- a/sync/javascript/src/lib.rs
+++ b/sync/javascript/src/lib.rs
@@ -3,19 +3,28 @@
pub mod generator;
pub mod js_protocol_io;
-use std::sync::{Arc, Mutex, OnceLock};
+use std::{
+ collections::HashMap,
+ sync::{Arc, Mutex, OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard},
+};
-use napi::bindgen_prelude::AsyncTask;
+use napi::{
+ bindgen_prelude::{AsyncTask, Either5, Function, FunctionRef, Null},
+ Env,
+};
use napi_derive::napi;
use tracing_subscriber::{filter::LevelFilter, fmt::format::FmtSpan};
use turso_node::IoLoopTask;
use turso_sync_engine::{
database_sync_engine::{DatabaseSyncEngine, DatabaseSyncEngineOpts},
- types::Coro,
+ types::{
+ Coro, DatabaseChangeType, DatabaseRowMutation, DatabaseRowStatement,
+ DatabaseSyncEngineProtocolVersion,
+ },
};
use crate::{
- generator::GeneratorHolder,
+ generator::{GeneratorHolder, GeneratorResponse},
js_protocol_io::{JsProtocolIo, JsProtocolRequestData},
};
@@ -29,18 +38,92 @@ pub struct SyncEngine {
path: String,
client_name: String,
wal_pull_batch_size: u32,
+ protocol_version: DatabaseSyncEngineProtocolVersion,
+ tables_ignore: Vec<String>,
+ transform: Option<FunctionRef<DatabaseRowMutationJs, Option<DatabaseRowStatementJs>>>,
io: Arc<dyn turso_core::IO>,
protocol: Arc<JsProtocolIo>,
- sync_engine: Arc<Mutex<Option<DatabaseSyncEngine<JsProtocolIo>>>>,
+ sync_engine: Arc<RwLock<Option<DatabaseSyncEngine<JsProtocolIo>>>>,
opened: Arc<Mutex<Option<turso_node::Database>>>,
}
+#[napi]
+pub enum DatabaseChangeTypeJs {
+ Insert,
+ Update,
+ Delete,
+}
+
+#[napi]
+pub enum SyncEngineProtocolVersion {
+ Legacy,
+ V1,
+}
+
+fn core_change_type_to_js(value: DatabaseChangeType) -> DatabaseChangeTypeJs {
+ match value {
+ DatabaseChangeType::Delete => DatabaseChangeTypeJs::Delete,
+ DatabaseChangeType::Update => DatabaseChangeTypeJs::Update,
+ DatabaseChangeType::Insert => DatabaseChangeTypeJs::Insert,
+ }
+}
+fn js_value_to_core(value: Either5<Null, i64, f64, String, Vec<u8>>) -> turso_core::Value {
+ match value {
+ Either5::A(_) => turso_core::Value::Null,
+ Either5::B(value) => turso_core::Value::Integer(value as i64),
+ Either5::C(value) => turso_core::Value::Float(value),
+ Either5::D(value) => turso_core::Value::Text(turso_core::types::Text::new(&value)),
+ Either5::E(value) => turso_core::Value::Blob(value),
+ }
+}
+fn core_value_to_js(value: turso_core::Value) -> Either5<Null, i64, f64, String, Vec<u8>> {
+ match value {
+ turso_core::Value::Null => Either5::<Null, i64, f64, String, Vec<u8>>::A(Null),
+ turso_core::Value::Integer(value) => Either5::<Null, i64, f64, String, Vec<u8>>::B(value),
+ turso_core::Value::Float(value) => Either5::<Null, i64, f64, String, Vec<u8>>::C(value),
+ turso_core::Value::Text(value) => {
+ Either5::<Null, i64, f64, String, Vec<u8>>::D(value.as_str().to_string())
+ }
+ turso_core::Value::Blob(value) => Either5::<Null, i64, f64, String, Vec<u8>>::E(value),
+ }
+}
+fn core_values_map_to_js(
+ value: HashMap<String, turso_core::Value>,
+) -> HashMap<String, Either5<Null, i64, f64, String, Vec<u8>>> {
+ let mut result = HashMap::new();
+ for (key, value) in value {
+ result.insert(key, core_value_to_js(value));
+ }
+ result
+}
+
#[napi(object)]
+pub struct DatabaseRowMutationJs {
+ pub change_time: i64,
+ pub table_name: String,
+ pub id: i64,
+ pub change_type: DatabaseChangeTypeJs,
+ pub before: Option<HashMap<String, Either5<Null, i64, f64, String, Vec<u8>>>>,
+ pub after: Option<HashMap<String, Either5<Null, i64, f64, String, Vec<u8>>>>,
+ pub updates: Option<HashMap<String, Either5<Null, i64, f64, String, Vec<u8>>>>,
+}
+
+#[napi(object)]
+#[derive(Debug)]
+pub struct DatabaseRowStatementJs {
+ pub sql: String,
+ pub values: Vec<Either5<Null, i64, f64, String, Vec<u8>>>,
+}
+
+#[napi(object, object_to_js = false)]
pub struct SyncEngineOpts {
pub path: String,
pub client_name: Option<String>,
pub wal_pull_batch_size: Option<u32>,
pub enable_tracing: Option<String>,
+ pub tables_ignore: Option<Vec<String>>,
+ pub transform: Option<Function<'static, DatabaseRowMutationJs, Option<DatabaseRowStatementJs>>>,
+ pub protocol_version: Option<SyncEngineProtocolVersion>,
}
static TRACING_INIT: OnceLock<()> = OnceLock::new();
@@ -81,19 +164,65 @@ impl SyncEngine {
path: opts.path,
client_name: 
opts.client_name.unwrap_or("turso-sync-js".to_string()),
wal_pull_batch_size: opts.wal_pull_batch_size.unwrap_or(100),
- sync_engine: Arc::new(Mutex::new(None)),
+ tables_ignore: opts.tables_ignore.unwrap_or(Vec::new()),
+ transform: opts.transform.map(|x| x.create_ref().unwrap()),
+ sync_engine: Arc::new(RwLock::new(None)),
io,
protocol: Arc::new(JsProtocolIo::default()),
#[allow(clippy::arc_with_non_send_sync)]
opened: Arc::new(Mutex::new(None)),
+ protocol_version: match opts.protocol_version {
+ Some(SyncEngineProtocolVersion::Legacy) | None => {
+ DatabaseSyncEngineProtocolVersion::Legacy
+ }
+ _ => DatabaseSyncEngineProtocolVersion::V1,
+ },
})
}
#[napi]
- pub fn init(&self) -> GeneratorHolder {
+ pub fn init(&mut self, env: Env) -> GeneratorHolder {
+ let transform: Option<
+ Arc<
+ dyn Fn(
+ &Env,
+ DatabaseRowMutation,
+ )
+ -> turso_sync_engine::Result<Option<DatabaseRowStatement>>
+ + 'static,
+ >,
+ > = match self.transform.take() {
+ Some(f) => Some(Arc::new(move |env, mutation| {
+ let result = f
+ .borrow_back(&env)
+ .unwrap()
+ .call(DatabaseRowMutationJs {
+ change_time: mutation.change_time as i64,
+ table_name: mutation.table_name,
+ id: mutation.id,
+ change_type: core_change_type_to_js(mutation.change_type),
+ before: mutation.before.map(core_values_map_to_js),
+ after: mutation.after.map(core_values_map_to_js),
+ updates: mutation.updates.map(core_values_map_to_js),
+ })
+ .map_err(|e| {
+ turso_sync_engine::errors::Error::DatabaseSyncEngineError(format!(
+ "transform callback failed: {e}"
+ ))
+ })?;
+ Ok(result.map(|statement| DatabaseRowStatement {
+ sql: statement.sql,
+ values: statement.values.into_iter().map(js_value_to_core).collect(),
+ }))
+ })),
+ None => None,
+ };
let opts = DatabaseSyncEngineOpts {
client_name: self.client_name.clone(),
wal_pull_batch_size: self.wal_pull_batch_size as u64,
+ tables_ignore: self.tables_ignore.clone(),
+ transform,
+ protocol_version_hint: self.protocol_version,
};
let protocol = self.protocol.clone();
@@ -102,17 +231,19 @@ impl SyncEngine {
let opened = self.opened.clone();
let path = self.path.clone();
let generator = genawaiter::sync::Gen::new(|coro| async move {
+ let coro = Coro::new(env, coro);
let initialized =
DatabaseSyncEngine::new(&coro, io.clone(), protocol, &path, opts).await?;
- let connection = initialized.connect(&coro).await?;
+ let connection = initialized.connect_rw(&coro).await?;
let db = turso_node::Database::create(None, io.clone(), connection, false);
- *sync_engine.lock().unwrap() = Some(initialized);
+ *sync_engine.write().unwrap() = Some(initialized);
*opened.lock().unwrap() = Some(db);
Ok(())
});
GeneratorHolder {
inner: Box::new(Mutex::new(generator)),
+ response: Arc::new(Mutex::new(None)),
}
}
@@ -137,18 +268,63 @@ impl SyncEngine {
}
#[napi]
- pub fn sync(&self) -> GeneratorHolder {
- self.run(async move |coro, sync_engine| sync_engine.sync(coro).await)
+ pub fn sync(&self, env: Env) -> GeneratorHolder {
+ self.run(env, async move |coro, sync_engine| {
+ let mut sync_engine = try_write(sync_engine)?;
+ let sync_engine = try_unwrap_mut(&mut sync_engine)?;
+ sync_engine.sync(coro).await?;
+ Ok(None)
+ })
}
#[napi]
- pub fn push(&self) -> GeneratorHolder {
- self.run(async move |coro, sync_engine| sync_engine.push(coro).await)
+ pub fn push(&self, env: Env) -> GeneratorHolder {
+ self.run(env, async move |coro, sync_engine| {
+ let sync_engine = try_read(sync_engine)?;
+ let sync_engine = try_unwrap(&sync_engine)?;
+ sync_engine.push_changes_to_remote(coro).await?;
+ Ok(None)
+ })
}
#[napi]
- pub fn pull(&self) -> 
GeneratorHolder {
- self.run(async move |coro, sync_engine| sync_engine.pull(coro).await)
+ pub fn stats(&self, env: Env) -> GeneratorHolder {
+ self.run(env, async move |coro, sync_engine| {
+ let sync_engine = try_read(sync_engine)?;
+ let sync_engine = try_unwrap(&sync_engine)?;
+ let changes = sync_engine.stats(coro).await?;
+ Ok(Some(GeneratorResponse::SyncEngineStats {
+ operations: changes.cdc_operations,
+ wal: changes.wal_size,
+ }))
+ })
+ }
+
+ #[napi]
+ pub fn pull(&self, env: Env) -> GeneratorHolder {
+ self.run(env, async move |coro, sync_engine| {
+ let changes = {
+ let sync_engine = try_read(sync_engine)?;
+ let sync_engine = try_unwrap(&sync_engine)?;
+ sync_engine.wait_changes_from_remote(coro).await?
+ };
+ if let Some(changes) = changes {
+ let mut sync_engine = try_write(sync_engine)?;
+ let sync_engine = try_unwrap_mut(&mut sync_engine)?;
+ sync_engine.apply_changes_from_remote(coro, changes).await?;
+ }
+ Ok(None)
+ })
+ }
+
+ #[napi]
+ pub fn checkpoint(&self, env: Env) -> GeneratorHolder {
+ self.run(env, async move |coro, sync_engine| {
+ let mut sync_engine = try_write(sync_engine)?;
+ let sync_engine = try_unwrap_mut(&mut sync_engine)?;
+ sync_engine.checkpoint(coro).await?;
+ Ok(None)
+ })
}
#[napi]
@@ -165,32 +341,76 @@ impl SyncEngine {
fn run(
&self,
+ env: Env,
f: impl AsyncFnOnce(
- &Coro,
- &mut DatabaseSyncEngine<JsProtocolIo>,
- ) -> turso_sync_engine::Result<()>
+ &Coro<Env>,
+ &Arc<RwLock<Option<DatabaseSyncEngine<JsProtocolIo>>>>,
+ ) -> turso_sync_engine::Result<Option<GeneratorResponse>>
+ 'static,
) -> GeneratorHolder {
+ let response = Arc::new(Mutex::new(None));
let sync_engine = self.sync_engine.clone();
#[allow(clippy::await_holding_lock)]
- let generator = genawaiter::sync::Gen::new(|coro| async move {
- let Ok(mut sync_engine) = sync_engine.try_lock() else {
- let nasty_error = "sync_engine is busy".to_string();
- return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError(
- nasty_error,
- ));
- };
- let Some(sync_engine) = sync_engine.as_mut() else {
- let error = "sync_engine must be initialized".to_string();
- return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError(
- error,
- ));
- };
- f(&coro, sync_engine).await?;
- Ok(())
+ let generator = genawaiter::sync::Gen::new({
+ let response = response.clone();
+ |coro| async move {
+ let coro = Coro::new(env, coro);
+ *response.lock().unwrap() = f(&coro, &sync_engine).await?;
+ Ok(())
+ }
});
GeneratorHolder {
inner: Box::new(Mutex::new(generator)),
+ response,
}
}
}
+
+fn try_read(
+ sync_engine: &RwLock<Option<DatabaseSyncEngine<JsProtocolIo>>>,
+) -> turso_sync_engine::Result<RwLockReadGuard<'_, Option<DatabaseSyncEngine<JsProtocolIo>>>> {
+ let Ok(sync_engine) = sync_engine.try_read() else {
+ let nasty_error = "sync_engine is busy".to_string();
+ return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError(
+ nasty_error,
+ ));
+ };
+ Ok(sync_engine)
+}
+
+fn try_write(
+ sync_engine: &RwLock<Option<DatabaseSyncEngine<JsProtocolIo>>>,
+) -> turso_sync_engine::Result<RwLockWriteGuard<'_, Option<DatabaseSyncEngine<JsProtocolIo>>>>
+{
+ let Ok(sync_engine) = sync_engine.try_write() else {
+ let nasty_error = "sync_engine is busy".to_string();
+ return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError(
+ nasty_error,
+ ));
+ };
+ Ok(sync_engine)
+}
+
+fn try_unwrap<'a>(
+ sync_engine: &'a RwLockReadGuard<'_, Option<DatabaseSyncEngine<JsProtocolIo>>>,
+) -> turso_sync_engine::Result<&'a DatabaseSyncEngine<JsProtocolIo>> {
+ let Some(sync_engine) = sync_engine.as_ref() else {
+ let error = "sync_engine must be initialized".to_string();
+ return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError(
+ error,
+ ));
+ };
+ Ok(sync_engine)
+}
+
+fn try_unwrap_mut<'a>(
+ sync_engine: &'a mut RwLockWriteGuard<'_, Option<DatabaseSyncEngine<JsProtocolIo>>>,
+) -> turso_sync_engine::Result<&'a mut DatabaseSyncEngine<JsProtocolIo>> 
{
+ let Some(sync_engine) = sync_engine.as_mut() else {
+ let error = "sync_engine must be initialized".to_string();
+ return Err(turso_sync_engine::errors::Error::DatabaseSyncEngineError(
+ error,
+ ));
+ };
+ Ok(sync_engine)
+}
diff --git a/sync/javascript/sync_engine.ts b/sync/javascript/sync_engine.ts
index dac182a76..fb9434f66 100644
--- a/sync/javascript/sync_engine.ts
+++ b/sync/javascript/sync_engine.ts
@@ -1,6 +1,6 @@
"use strict";
-import { SyncEngine } from '#entry-point';
+import { SyncEngine, DatabaseRowMutationJs, DatabaseRowStatementJs } from '#entry-point';
import { Database } from '@tursodatabase/database';
const GENERATOR_RESUME_IO = 0;
@@ -63,9 +63,16 @@ async function process(opts, request) {
const completion = request.completion();
if (requestType.type == 'Http') {
try {
+ let headers = opts.headers;
+ if (requestType.headers != null && requestType.headers.length > 0) {
+ headers = { ...opts.headers };
+ for (let header of requestType.headers) {
+ headers[header[0]] = header[1];
+ }
+ }
const response = await fetch(`${opts.url}${requestType.path}`, {
method: requestType.method,
- headers: opts.headers,
+ headers: headers,
body: requestType.body != null ? new Uint8Array(requestType.body) : null,
});
completion.status(response.status);
@@ -101,7 +108,7 @@
}
}
-async function run(opts, engine, generator) {
+async function run(opts, engine, generator): Promise<any> {
let tasks = [];
while (generator.resume(null) !== GENERATOR_RESUME_DONE) {
for (let request = engine.protocolIo(); request != null; request = engine.protocolIo()) {
@@ -113,6 +120,7 @@
tasks = tasks.filter(t => !t.finished);
}
+ return generator.take();
}
interface ConnectOpts {
@@ -121,16 +129,27 @@
url: string;
authToken?: string;
encryptionKey?: string;
+ tablesIgnore?: string[],
+ transform?: (arg: DatabaseRowMutationJs) => DatabaseRowStatementJs | null,
+ enableTracing?: string,
}
interface Sync {
sync(): Promise<void>;
push(): Promise<void>;
pull(): Promise<void>;
+ checkpoint(): Promise<void>;
+ stats(): Promise<{ operations: number, wal: number }>;
}
export async function connect(opts: ConnectOpts): Database & Sync {
- const engine = new SyncEngine({ path: opts.path, clientName: opts.clientName });
+ const engine = new SyncEngine({
+ path: opts.path,
+ clientName: opts.clientName,
+ tablesIgnore: opts.tablesIgnore,
+ transform: opts.transform,
+ enableTracing: opts.enableTracing
+ });
const httpOpts = {
url: opts.url,
headers: {
@@ -147,5 +166,9 @@
db.sync = async function () { await run(httpOpts, engine, engine.sync()); }
db.pull = async function () { await run(httpOpts, engine, engine.pull()); }
db.push = async function () { await run(httpOpts, engine, engine.push()); }
+ db.checkpoint = async function () { await run(httpOpts, engine, engine.checkpoint()); }
+ db.stats = async function () { return (await run(httpOpts, engine, engine.stats())); }
return db;
}
+
+export { Database, Sync };
diff --git a/sync/javascript/turso-sync-js.wasi-browser.js b/sync/javascript/turso-sync-js.wasi-browser.js
index 1f5b8c547..55e6a698d 100644
--- a/sync/javascript/turso-sync-js.wasi-browser.js
+++ b/sync/javascript/turso-sync-js.wasi-browser.js
@@ -64,3 +64,5 @@
export const JsDataPollResult = __napiModule.exports.JsDataPollResult
export const JsProtocolIo = __napiModule.exports.JsProtocolIo
export const JsProtocolRequestData = __napiModule.exports.JsProtocolRequestData
export const SyncEngine = __napiModule.exports.SyncEngine +export const DatabaseChangeTypeJs = __napiModule.exports.DatabaseChangeTypeJs +export const SyncEngineProtocolVersion = __napiModule.exports.SyncEngineProtocolVersion diff --git a/sync/javascript/turso-sync-js.wasi.cjs b/sync/javascript/turso-sync-js.wasi.cjs index f5f3b1763..43e50a77f 100644 --- a/sync/javascript/turso-sync-js.wasi.cjs +++ b/sync/javascript/turso-sync-js.wasi.cjs @@ -116,3 +116,5 @@ module.exports.JsDataPollResult = __napiModule.exports.JsDataPollResult module.exports.JsProtocolIo = __napiModule.exports.JsProtocolIo module.exports.JsProtocolRequestData = __napiModule.exports.JsProtocolRequestData module.exports.SyncEngine = __napiModule.exports.SyncEngine +module.exports.DatabaseChangeTypeJs = __napiModule.exports.DatabaseChangeTypeJs +module.exports.SyncEngineProtocolVersion = __napiModule.exports.SyncEngineProtocolVersion
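
Editor's usage sketch (not part of the patch): the JavaScript bindings above gain `tablesIgnore`, `transform` and `enableTracing` connect options plus `checkpoint()` and `stats()` methods. The sketch below shows how these compose, assuming the published package re-exports `connect` from sync_engine.ts; the import path, the endpoint URL, and the `jobs` table with its `payload` column are illustrative assumptions, not taken from the patch:

    import { connect } from '@tursodatabase/sync'; // package name assumed

    const db = await connect({
      path: 'local.db',
      url: 'https://example-db.turso.example',     // hypothetical sync endpoint
      authToken: process.env.TURSO_AUTH_TOKEN,
      tablesIgnore: ['local_cache'],               // changes for this table are not replayed
      // Rewrite a replayed mutation into a custom statement;
      // returning null falls back to the default replay logic.
      transform: (mutation) => {
        if (mutation.tableName === 'jobs') {
          return {
            sql: 'INSERT INTO jobs(id, payload) VALUES (?, ?)',
            values: [mutation.id, mutation.after?.payload ?? null],
          };
        }
        return null;
      },
    });

    await db.pull();                              // wait for and apply remote changes
    const { operations, wal } = await db.stats(); // pending CDC operations and WAL size
    if (operations > 0) {
      await db.push();                            // push local changes to the remote
    }
    await db.checkpoint();

Note that the transform callback is invoked synchronously on the replay path (see the `transform(&coro.ctx.borrow(), mutation)?` call in database_replay_generator.rs), so it should be a fast, pure function rather than one that performs I/O.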