use std::sync::Arc;

use bytes::BytesMut;
use prost::Message;
use turso_core::{
    types::{Text, WalFrameInfo},
    Buffer, Completion, LimboError, OpenFlags, Value,
};

use crate::{
    database_replay_generator::DatabaseReplayGenerator,
    database_sync_engine::DatabaseSyncEngineOpts,
    database_tape::{
        run_stmt_expect_one_row, run_stmt_ignore_rows, DatabaseChangesIteratorMode,
        DatabaseChangesIteratorOpts, DatabaseReplaySessionOpts, DatabaseTape, DatabaseWalSession,
    },
    errors::Error,
    io_operations::IoOperations,
    protocol_io::{DataCompletion, DataPollResult, ProtocolIO},
    server_proto::{
        self, ExecuteStreamReq, PageData, PageUpdatesEncodingReq, PullUpdatesReqProtoBody,
        PullUpdatesRespProtoBody, Stmt, StmtResult, StreamRequest,
    },
    types::{
        Coro, DatabasePullRevision, DatabaseSyncEngineProtocolVersion, DatabaseTapeOperation,
        DatabaseTapeRowChangeType, DbSyncInfo, DbSyncStatus, ProtocolCommand,
    },
    wal_session::WalSession,
    Result,
};

pub const WAL_HEADER: usize = 32;
pub const WAL_FRAME_HEADER: usize = 24;
pub const PAGE_SIZE: usize = 4096;
pub const WAL_FRAME_SIZE: usize = WAL_FRAME_HEADER + PAGE_SIZE;

enum WalHttpPullResult<C: DataCompletion> {
    Frames(C),
    NeedCheckpoint(DbSyncStatus),
}

pub enum WalPushResult {
    Ok { baton: Option<String> },
    NeedCheckpoint,
}

pub fn connect_untracked(tape: &DatabaseTape) -> Result<Arc<turso_core::Connection>> {
    let conn = tape.connect_untracked()?;
    conn.wal_auto_checkpoint_disable();
    Ok(conn)
}

/// Bootstrap the DB file from the latest generation on the remote
pub async fn db_bootstrap<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    db: Arc<dyn turso_core::File>,
) -> Result<DbSyncInfo> {
    tracing::debug!("db_bootstrap");
    let start_time = std::time::Instant::now();
    let db_info = db_info_http(coro, client).await?;
    tracing::debug!("db_bootstrap: fetched db_info={db_info:?}");
    let content = db_bootstrap_http(coro, client, db_info.current_generation).await?;
    let mut pos = 0;
    loop {
        while let Some(chunk) = content.poll_data()? {
            let chunk = chunk.data();
            let content_len = chunk.len();
            // todo(sivukhin): optimize allocations here
            #[allow(clippy::arc_with_non_send_sync)]
            let buffer = Arc::new(Buffer::new_temporary(chunk.len()));
            buffer.as_mut_slice().copy_from_slice(chunk);
            let c = Completion::new_write(move |result| {
                // todo(sivukhin): we need to error out in case of partial write
                let Ok(size) = result else {
                    return;
                };
                assert!(size as usize == content_len);
            });
            let c = db.pwrite(pos, buffer.clone(), c)?;
            while !c.is_completed() {
                coro.yield_(ProtocolCommand::IO).await?;
            }
            pos += content_len;
        }
        if content.is_done()?
        {
            break;
        }
        coro.yield_(ProtocolCommand::IO).await?;
    }
    // sync the file at the end
    let c = Completion::new_sync(move |_| {
        // todo(sivukhin): we need to error out in case of failed sync
    });
    let c = db.sync(c)?;
    while !c.is_completed() {
        coro.yield_(ProtocolCommand::IO).await?;
    }
    let elapsed = std::time::Instant::now().duration_since(start_time);
    tracing::debug!("db_bootstrap: finished: bytes={pos}, elapsed={:?}", elapsed);
    Ok(db_info)
}

pub async fn wal_apply_from_file<Ctx>(
    coro: &Coro<Ctx>,
    frames_file: Arc<dyn turso_core::File>,
    session: &mut DatabaseWalSession,
) -> Result<u32> {
    let size = frames_file.size()?;
    assert!(size % WAL_FRAME_SIZE as u64 == 0);
    #[allow(clippy::arc_with_non_send_sync)]
    let buffer = Arc::new(Buffer::new_temporary(WAL_FRAME_SIZE));
    tracing::debug!("wal_apply_from_file: size={}", size);
    let mut db_size = 0;
    for offset in (0..size).step_by(WAL_FRAME_SIZE) {
        let c = Completion::new_read(buffer.clone(), move |result| {
            let Ok((_, size)) = result else {
                return;
            };
            // todo(sivukhin): we need to error out in case of partial read
            assert!(size as usize == WAL_FRAME_SIZE);
        });
        let c = frames_file.pread(offset as usize, c)?;
        while !c.is_completed() {
            coro.yield_(ProtocolCommand::IO).await?;
        }
        let info = WalFrameInfo::from_frame_header(buffer.as_slice());
        tracing::debug!("got frame: {:?}", info);
        db_size = info.db_size;
        session.append_page(info.page_no, &buffer.as_slice()[WAL_FRAME_HEADER..])?;
    }
    assert!(db_size > 0);
    Ok(db_size)
}

pub async fn wal_pull_to_file<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    frames_file: Arc<dyn turso_core::File>,
    revision: &DatabasePullRevision,
    wal_pull_batch_size: u64,
) -> Result<DatabasePullRevision> {
    match revision {
        DatabasePullRevision::Legacy {
            generation,
            synced_frame_no,
        } => {
            let start_frame = synced_frame_no.unwrap_or(0) + 1;
            wal_pull_to_file_legacy(
                coro,
                client,
                frames_file,
                *generation,
                start_frame,
                wal_pull_batch_size,
            )
            .await
        }
        DatabasePullRevision::V1 { revision } => {
            wal_pull_to_file_v1(coro, client, frames_file, revision).await
        }
    }
}

/// Pull updates from the remote into a separate frames file
pub async fn wal_pull_to_file_v1<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    frames_file: Arc<dyn turso_core::File>,
    revision: &str,
) -> Result<DatabasePullRevision> {
    tracing::info!("wal_pull: revision={revision}");
    let mut bytes = BytesMut::new();
    let request = PullUpdatesReqProtoBody {
        encoding: PageUpdatesEncodingReq::Raw as i32,
        server_revision: String::new(),
        client_revision: revision.to_string(),
        long_poll_timeout_ms: 0,
        server_pages: BytesMut::new().into(),
        client_pages: BytesMut::new().into(),
    };
    let request = request.encode_to_vec();
    let completion = client.http(
        "POST",
        "/pull-updates",
        Some(request),
        &[
            ("content-type", "application/protobuf"),
            ("accept-encoding", "application/protobuf"),
        ],
    )?;
    let Some(header) =
        wait_proto_message::<Ctx, PullUpdatesRespProtoBody>(coro, &completion, &mut bytes).await?
    else {
        return Err(Error::DatabaseSyncEngineError(
            "no header returned in the pull-updates protobuf call".to_string(),
        ));
    };
    tracing::info!("wal_pull_to_file: got header={:?}", header);
    let mut offset = 0;
    #[allow(clippy::arc_with_non_send_sync)]
    let buffer = Arc::new(Buffer::new_temporary(WAL_FRAME_SIZE));
    let mut page_data_opt =
        wait_proto_message::<Ctx, PageData>(coro, &completion, &mut bytes).await?;
    while let Some(page_data) = page_data_opt.take() {
        let page_id = page_data.page_id;
        tracing::info!("received page {}", page_id);
        let page = decode_page(&header, page_data)?;
        if page.len() != PAGE_SIZE {
            return Err(Error::DatabaseSyncEngineError(format!(
                "page has unexpected size: {} != {}",
                page.len(),
                PAGE_SIZE
            )));
        }
        buffer.as_mut_slice()[WAL_FRAME_HEADER..].copy_from_slice(&page);
        page_data_opt = wait_proto_message(coro, &completion, &mut bytes).await?;
        let mut frame_info = WalFrameInfo {
            db_size: 0,
            page_no: page_id as u32 + 1,
        };
        if page_data_opt.is_none() {
            frame_info.db_size = header.db_size as u32;
        }
        tracing::info!("page_data_opt: {}", page_data_opt.is_some());
        frame_info.put_to_frame_header(buffer.as_mut_slice());
        let c = Completion::new_write(move |result| {
            // todo(sivukhin): we need to error out in case of partial write
            let Ok(size) = result else {
                return;
            };
            assert!(size as usize == WAL_FRAME_SIZE);
        });
        let c = frames_file.pwrite(offset, buffer.clone(), c)?;
        while !c.is_completed() {
            coro.yield_(ProtocolCommand::IO).await?;
        }
        offset += WAL_FRAME_SIZE;
    }
    let c = Completion::new_sync(move |_| {
        // todo(sivukhin): we need to error out in case of failed sync
    });
    let c = frames_file.sync(c)?;
    while !c.is_completed() {
        coro.yield_(ProtocolCommand::IO).await?;
    }
    Ok(DatabasePullRevision::V1 {
        revision: header.server_revision,
    })
}

/// Pull updates from the remote into a separate frames file
pub async fn wal_pull_to_file_legacy<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    frames_file: Arc<dyn turso_core::File>,
    mut generation: u64,
    mut start_frame: u64,
    wal_pull_batch_size: u64,
) -> Result<DatabasePullRevision> {
    tracing::info!(
        "wal_pull: generation={generation}, start_frame={start_frame}, wal_pull_batch_size={wal_pull_batch_size}"
    );
    // todo(sivukhin): optimize allocation by using buffer pool in the DatabaseSyncOperations
    #[allow(clippy::arc_with_non_send_sync)]
    let buffer = Arc::new(Buffer::new_temporary(WAL_FRAME_SIZE));
    let mut buffer_len = 0;
    let mut last_offset = 0;
    let mut committed_len = 0;
    let revision = loop {
        let end_frame = start_frame + wal_pull_batch_size;
        let result = wal_pull_http(coro, client, generation, start_frame, end_frame).await?;
        let data = match result {
            WalHttpPullResult::NeedCheckpoint(status) => {
                assert!(status.status == "checkpoint_needed");
                tracing::debug!("wal_pull: need checkpoint: status={status:?}");
                if status.generation == generation && status.max_frame_no < start_frame {
                    tracing::debug!("wal_pull: end of history: status={:?}", status);
                    break DatabasePullRevision::Legacy {
                        generation: status.generation,
                        synced_frame_no: Some(status.max_frame_no),
                    };
                }
                generation += 1;
                start_frame = 1;
                continue;
            }
            WalHttpPullResult::Frames(content) => content,
        };
        loop {
            while let Some(chunk) = data.poll_data()?
            {
                let mut chunk = chunk.data();
                while !chunk.is_empty() {
                    let to_fill = (WAL_FRAME_SIZE - buffer_len).min(chunk.len());
                    buffer.as_mut_slice()[buffer_len..buffer_len + to_fill]
                        .copy_from_slice(&chunk[0..to_fill]);
                    buffer_len += to_fill;
                    chunk = &chunk[to_fill..];
                    if buffer_len < WAL_FRAME_SIZE {
                        continue;
                    }
                    let c = Completion::new_write(move |result| {
                        // todo(sivukhin): we need to error out in case of partial write
                        let Ok(size) = result else {
                            return;
                        };
                        assert!(size as usize == WAL_FRAME_SIZE);
                    });
                    let c = frames_file.pwrite(last_offset, buffer.clone(), c)?;
                    while !c.is_completed() {
                        coro.yield_(ProtocolCommand::IO).await?;
                    }
                    last_offset += WAL_FRAME_SIZE;
                    buffer_len = 0;
                    start_frame += 1;
                    let info = WalFrameInfo::from_frame_header(buffer.as_slice());
                    if info.is_commit_frame() {
                        committed_len = last_offset;
                    }
                }
            }
            if data.is_done()? {
                break;
            }
            coro.yield_(ProtocolCommand::IO).await?;
        }
        if start_frame < end_frame {
            // the chunk sent from the server ended early - so there is nothing left on the server side to pull
            break DatabasePullRevision::Legacy {
                generation,
                synced_frame_no: Some(start_frame - 1),
            };
        }
        if buffer_len != 0 {
            return Err(Error::DatabaseSyncEngineError(format!(
                "wal_pull: response has unexpected trailing data: buffer_len={buffer_len}"
            )));
        }
    };
    tracing::info!(
        "wal_pull: generation={generation}, frame={start_frame}, last_offset={last_offset}, committed_len={committed_len}"
    );
    let c = Completion::new_trunc(move |result| {
        let Ok(rc) = result else {
            return;
        };
        assert!(rc as usize == 0);
    });
    let c = frames_file.truncate(committed_len, c)?;
    while !c.is_completed() {
        coro.yield_(ProtocolCommand::IO).await?;
    }
    let c = Completion::new_sync(move |_| {
        // todo(sivukhin): we need to error out in case of failed sync
    });
    let c = frames_file.sync(c)?;
    while !c.is_completed() {
        coro.yield_(ProtocolCommand::IO).await?;
    }
    Ok(revision)
}

/// Push frame range [start_frame..end_frame) to the remote
/// Returns the baton for the remote WAL session in case of success
/// Returns [Error::DatabaseSyncEngineConflict] in case of a frame conflict on the remote side
///
/// Guarantees:
/// 1.
///    If there is a single client which calls wal_push, then this operation is idempotent for a fixed generation
///    and can be called multiple times with the same frame range
pub async fn wal_push<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    wal_session: &mut WalSession,
    baton: Option<String>,
    generation: u64,
    start_frame: u64,
    end_frame: u64,
) -> Result<WalPushResult> {
    assert!(wal_session.in_txn());
    tracing::info!("wal_push: baton={baton:?}, generation={generation}, start_frame={start_frame}, end_frame={end_frame}");
    if start_frame == end_frame {
        return Ok(WalPushResult::Ok { baton: None });
    }
    let mut frames_data = Vec::with_capacity((end_frame - start_frame) as usize * WAL_FRAME_SIZE);
    let mut buffer = [0u8; WAL_FRAME_SIZE];
    for frame_no in start_frame..end_frame {
        let frame_info = wal_session.read_at(frame_no, &mut buffer)?;
        tracing::trace!(
            "wal_push: collect frame {} ({:?}) for push",
            frame_no,
            frame_info
        );
        frames_data.extend_from_slice(&buffer);
    }
    let status = wal_push_http(
        coro,
        client,
        None,
        generation,
        start_frame,
        end_frame,
        frames_data,
    )
    .await?;
    if status.status == "ok" {
        Ok(WalPushResult::Ok {
            baton: status.baton,
        })
    } else if status.status == "checkpoint_needed" {
        Ok(WalPushResult::NeedCheckpoint)
    } else if status.status == "conflict" {
        Err(Error::DatabaseSyncEngineConflict(format!(
            "wal_push conflict: {status:?}"
        )))
    } else {
        Err(Error::DatabaseSyncEngineError(format!(
            "wal_push unexpected status: {status:?}"
        )))
    }
}

pub const TURSO_SYNC_TABLE_NAME: &str = "turso_sync_last_change_id";
pub const TURSO_SYNC_CREATE_TABLE: &str = "CREATE TABLE IF NOT EXISTS turso_sync_last_change_id (client_id TEXT PRIMARY KEY, pull_gen INTEGER, change_id INTEGER)";
pub const TURSO_SYNC_INSERT_LAST_CHANGE_ID: &str = "INSERT INTO turso_sync_last_change_id(client_id, pull_gen, change_id) VALUES (?, ?, ?)";
pub const TURSO_SYNC_UPSERT_LAST_CHANGE_ID: &str = "INSERT INTO turso_sync_last_change_id(client_id, pull_gen, change_id) VALUES (?, ?, ?) ON CONFLICT(client_id) DO UPDATE SET pull_gen=excluded.pull_gen, change_id=excluded.change_id";
pub const TURSO_SYNC_UPDATE_LAST_CHANGE_ID: &str = "UPDATE turso_sync_last_change_id SET pull_gen = ?, change_id = ? WHERE client_id = ?";
const TURSO_SYNC_SELECT_LAST_CHANGE_ID: &str = "SELECT pull_gen, change_id FROM turso_sync_last_change_id WHERE client_id = ?";

fn convert_to_args(values: Vec<Value>) -> Vec<server_proto::Value> {
    values
        .into_iter()
        .map(|value| match value {
            Value::Null => server_proto::Value::Null,
            Value::Integer(value) => server_proto::Value::Integer { value },
            Value::Float(value) => server_proto::Value::Float { value },
            Value::Text(value) => server_proto::Value::Text {
                value: value.as_str().to_string(),
            },
            Value::Blob(value) => server_proto::Value::Blob {
                value: value.into(),
            },
        })
        .collect()
}

pub async fn has_table<Ctx>(
    coro: &Coro<Ctx>,
    conn: &Arc<turso_core::Connection>,
    table_name: &str,
) -> Result<bool> {
    let mut stmt =
        conn.prepare("SELECT COUNT(*) FROM sqlite_schema WHERE type = 'table' AND name = ?")?;
    stmt.bind_at(1.try_into().unwrap(), Value::Text(Text::new(table_name)));
    let count = match run_stmt_expect_one_row(coro, &mut stmt).await? {
        Some(row) => row[0]
            .as_int()
            .ok_or_else(|| Error::DatabaseSyncEngineError("unexpected column type".to_string()))?,
        _ => panic!("expected single row"),
    };
    Ok(count > 0)
}

pub async fn count_local_changes<Ctx>(
    coro: &Coro<Ctx>,
    conn: &Arc<turso_core::Connection>,
    change_id: i64,
) -> Result<i64> {
    let mut stmt = conn.prepare("SELECT COUNT(*) FROM turso_cdc WHERE change_id > ?")?;
    stmt.bind_at(1.try_into().unwrap(), Value::Integer(change_id));
    let count = match run_stmt_expect_one_row(coro, &mut stmt).await?
    {
        Some(row) => row[0]
            .as_int()
            .ok_or_else(|| Error::DatabaseSyncEngineError("unexpected column type".to_string()))?,
        _ => panic!("expected single row"),
    };
    Ok(count)
}

pub async fn update_last_change_id<Ctx>(
    coro: &Coro<Ctx>,
    conn: &Arc<turso_core::Connection>,
    client_id: &str,
    pull_gen: i64,
    change_id: i64,
) -> Result<()> {
    tracing::info!(
        "update_last_change_id(client_id={client_id}): pull_gen={pull_gen}, change_id={change_id}"
    );
    conn.execute(TURSO_SYNC_CREATE_TABLE)?;
    tracing::info!("update_last_change_id(client_id={client_id}): initialized table");
    let mut select_stmt = conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID)?;
    select_stmt.bind_at(
        1.try_into().unwrap(),
        turso_core::Value::Text(turso_core::types::Text::new(client_id)),
    );
    let row = run_stmt_expect_one_row(coro, &mut select_stmt).await?;
    tracing::info!("update_last_change_id(client_id={client_id}): selected client row if any");
    if row.is_some() {
        let mut update_stmt = conn.prepare(TURSO_SYNC_UPDATE_LAST_CHANGE_ID)?;
        update_stmt.bind_at(1.try_into().unwrap(), turso_core::Value::Integer(pull_gen));
        update_stmt.bind_at(2.try_into().unwrap(), turso_core::Value::Integer(change_id));
        update_stmt.bind_at(
            3.try_into().unwrap(),
            turso_core::Value::Text(turso_core::types::Text::new(client_id)),
        );
        run_stmt_ignore_rows(coro, &mut update_stmt).await?;
        tracing::info!("update_last_change_id(client_id={client_id}): updated row for the client");
    } else {
        let mut update_stmt = conn.prepare(TURSO_SYNC_INSERT_LAST_CHANGE_ID)?;
        update_stmt.bind_at(
            1.try_into().unwrap(),
            turso_core::Value::Text(turso_core::types::Text::new(client_id)),
        );
        update_stmt.bind_at(2.try_into().unwrap(), turso_core::Value::Integer(pull_gen));
        update_stmt.bind_at(3.try_into().unwrap(), turso_core::Value::Integer(change_id));
        run_stmt_ignore_rows(coro, &mut update_stmt).await?;
        tracing::info!(
            "update_last_change_id(client_id={client_id}): inserted new row for the client"
        );
    }
    Ok(())
}

pub async fn fetch_last_change_id<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    source_conn: &Arc<turso_core::Connection>,
    client_id: &str,
) -> Result<(i64, Option<i64>)> {
    tracing::info!("fetch_last_change_id: client_id={client_id}");
    // fetch the last pull_gen stored in the source (local) DB in order to guarantee atomic replay of changes and avoid conflicts in case of failure
    let source_pull_gen = 'source_pull_gen: {
        let mut select_last_change_id_stmt =
            match source_conn.prepare(TURSO_SYNC_SELECT_LAST_CHANGE_ID) {
                Ok(stmt) => stmt,
                Err(LimboError::ParseError(..)) => break 'source_pull_gen 0,
                Err(err) => return Err(err.into()),
            };
        select_last_change_id_stmt
            .bind_at(1.try_into().unwrap(), Value::Text(Text::new(client_id)));
        match run_stmt_expect_one_row(coro, &mut select_last_change_id_stmt).await?
        {
            Some(row) => row[0].as_int().ok_or_else(|| {
                Error::DatabaseSyncEngineError("unexpected source pull_gen type".to_string())
            })?,
            None => {
                tracing::info!("fetch_last_change_id: client_id={client_id}, turso_sync_last_change_id table is not found");
                0
            }
        }
    };
    tracing::info!(
        "fetch_last_change_id: client_id={client_id}, source_pull_gen={source_pull_gen}"
    );
    // fetch last_change_id from the target DB in order to guarantee atomic replay of changes and avoid conflicts in case of failure
    let init_hrana_request = server_proto::PipelineReqBody {
        baton: None,
        requests: vec![
            // read pull_gen, change_id values for current client if they were set before
            StreamRequest::Execute(ExecuteStreamReq {
                stmt: Stmt {
                    sql: Some(TURSO_SYNC_SELECT_LAST_CHANGE_ID.to_string()),
                    sql_id: None,
                    args: vec![server_proto::Value::Text {
                        value: client_id.to_string(),
                    }],
                    named_args: Vec::new(),
                    want_rows: Some(true),
                    replication_index: None,
                },
            }),
        ]
        .into(),
    };
    let response = match sql_execute_http(coro, client, init_hrana_request).await {
        Ok(response) => response,
        Err(Error::DatabaseSyncEngineError(err)) if err.contains("no such table") => {
            return Ok((source_pull_gen, None));
        }
        Err(err) => return Err(err),
    };
    assert!(response.len() == 1);
    let last_change_id_response = &response[0];
    tracing::debug!("fetch_last_change_id: response={:?}", response);
    assert!(last_change_id_response.rows.len() <= 1);
    if last_change_id_response.rows.is_empty() {
        return Ok((source_pull_gen, None));
    }
    let row = &last_change_id_response.rows[0].values;
    let server_proto::Value::Integer {
        value: target_pull_gen,
    } = row[0]
    else {
        return Err(Error::DatabaseSyncEngineError(
            "unexpected target pull_gen type".to_string(),
        ));
    };
    let server_proto::Value::Integer {
        value: target_change_id,
    } = row[1]
    else {
        return Err(Error::DatabaseSyncEngineError(
            "unexpected target change_id type".to_string(),
        ));
    };
    tracing::debug!(
        "fetch_last_change_id: client_id={client_id}, target_pull_gen={target_pull_gen}, target_change_id={target_change_id}"
    );
    if target_pull_gen > source_pull_gen {
        return Err(Error::DatabaseSyncEngineError(format!("protocol error: target_pull_gen > source_pull_gen: {target_pull_gen} > {source_pull_gen}")));
    }
    let last_change_id = if target_pull_gen == source_pull_gen {
        Some(target_change_id)
    } else {
        Some(0)
    };
    Ok((source_pull_gen, last_change_id))
}

pub async fn push_logical_changes<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    source: &DatabaseTape,
    client_id: &str,
    opts: &DatabaseSyncEngineOpts,
) -> Result<(i64, i64)> {
    tracing::info!("push_logical_changes: client_id={client_id}");
    let source_conn = connect_untracked(source)?;
    let (source_pull_gen, mut last_change_id) =
        fetch_last_change_id(coro, client, &source_conn, client_id).await?;
    tracing::debug!("push_logical_changes: last_change_id={:?}", last_change_id);
    let replay_opts = DatabaseReplaySessionOpts {
        use_implicit_rowid: false,
        transform: None,
    };
    let generator = DatabaseReplayGenerator::new(source_conn, replay_opts);
    let iterate_opts = DatabaseChangesIteratorOpts {
        first_change_id: last_change_id.map(|x| x + 1),
        mode: DatabaseChangesIteratorMode::Apply,
        ignore_schema_changes: false,
        ..Default::default()
    };
    let mut sql_over_http_requests = vec![
        Stmt {
            sql: Some("BEGIN IMMEDIATE".to_string()),
            sql_id: None,
            args: Vec::new(),
            named_args: Vec::new(),
            want_rows: Some(false),
            replication_index: None,
        },
        Stmt {
            sql: Some(TURSO_SYNC_CREATE_TABLE.to_string()),
            sql_id: None,
            args: Vec::new(),
            named_args: Vec::new(),
            want_rows: Some(false),
            replication_index: None,
        },
    ];
    let mut rows_changed = 0;
    let mut
        changes = source.iterate_changes(iterate_opts)?;
    while let Some(operation) = changes.next(coro).await? {
        match operation {
            DatabaseTapeOperation::RowChange(change) => {
                assert!(
                    last_change_id.is_none() || last_change_id.unwrap() < change.change_id,
                    "change id must be strictly increasing: last_change_id={:?}, change.change_id={}",
                    last_change_id,
                    change.change_id
                );
                if change.table_name == TURSO_SYNC_TABLE_NAME {
                    continue;
                }
                let ignore = &opts.tables_ignore;
                if ignore.iter().any(|x| &change.table_name == x) {
                    continue;
                }
                rows_changed += 1;
                // we give user full control over CDC table - so let's not emit assert here for now
                if last_change_id.is_some() && last_change_id.unwrap() + 1 != change.change_id {
                    tracing::warn!(
                        "out of order change sequence: {} -> {}",
                        last_change_id.unwrap(),
                        change.change_id
                    );
                }
                last_change_id = Some(change.change_id);
                let replay_info = generator.replay_info(coro, &change).await?;
                if !replay_info.is_ddl_replay {
                    if let Some(transform) = &opts.transform {
                        let mutation = generator.create_mutation(&replay_info, &change)?;
                        if let Some(statement) = transform(&coro.ctx.borrow(), mutation)? {
                            tracing::info!(
                                "push_logical_changes: use mutation from custom transformer: sql={}, values={:?}",
                                statement.sql,
                                statement.values
                            );
                            sql_over_http_requests.push(Stmt {
                                sql: Some(statement.sql),
                                sql_id: None,
                                args: convert_to_args(statement.values),
                                named_args: Vec::new(),
                                want_rows: Some(false),
                                replication_index: None,
                            });
                            continue;
                        }
                    }
                }
                let change_type = (&change.change).into();
                match change.change {
                    DatabaseTapeRowChangeType::Delete { before } => {
                        let values = generator.replay_values(
                            &replay_info,
                            change_type,
                            change.id,
                            before,
                            None,
                        );
                        sql_over_http_requests.push(Stmt {
                            sql: Some(replay_info.query.clone()),
                            sql_id: None,
                            args: convert_to_args(values),
                            named_args: Vec::new(),
                            want_rows: Some(false),
                            replication_index: None,
                        })
                    }
                    DatabaseTapeRowChangeType::Insert { after } => {
                        let values = generator.replay_values(
                            &replay_info,
                            change_type,
                            change.id,
                            after,
                            None,
                        );
                        sql_over_http_requests.push(Stmt {
                            sql: Some(replay_info.query.clone()),
                            sql_id: None,
                            args: convert_to_args(values),
                            named_args: Vec::new(),
                            want_rows: Some(false),
                            replication_index: None,
                        })
                    }
                    DatabaseTapeRowChangeType::Update {
                        after,
                        updates: Some(updates),
                        ..
                    } => {
                        let values = generator.replay_values(
                            &replay_info,
                            change_type,
                            change.id,
                            after,
                            Some(updates),
                        );
                        sql_over_http_requests.push(Stmt {
                            sql: Some(replay_info.query.clone()),
                            sql_id: None,
                            args: convert_to_args(values),
                            named_args: Vec::new(),
                            want_rows: Some(false),
                            replication_index: None,
                        })
                    }
                    DatabaseTapeRowChangeType::Update {
                        after,
                        updates: None,
                        ..
                    } => {
                        let values = generator.replay_values(
                            &replay_info,
                            change_type,
                            change.id,
                            after,
                            None,
                        );
                        sql_over_http_requests.push(Stmt {
                            sql: Some(replay_info.query.clone()),
                            sql_id: None,
                            args: convert_to_args(values),
                            named_args: Vec::new(),
                            want_rows: Some(false),
                            replication_index: None,
                        });
                    }
                }
            }
            DatabaseTapeOperation::Commit => {
                if rows_changed > 0 {
                    tracing::info!("prepare update stmt for turso_sync_last_change_id table with client_id={} and last_change_id={:?}", client_id, last_change_id);
                    // update turso_sync_last_change_id table with the new value before commit
                    let next_change_id = last_change_id.unwrap_or(0);
                    tracing::info!("push_logical_changes: client_id={client_id}, set pull_gen={source_pull_gen}, change_id={next_change_id}, rows_changed={rows_changed}");
                    sql_over_http_requests.push(Stmt {
                        sql: Some(TURSO_SYNC_UPSERT_LAST_CHANGE_ID.to_string()),
                        sql_id: None,
                        args: vec![
                            server_proto::Value::Text {
                                value: client_id.to_string(),
                            },
                            server_proto::Value::Integer {
                                value: source_pull_gen,
                            },
                            server_proto::Value::Integer {
                                value: next_change_id,
                            },
                        ],
                        named_args: Vec::new(),
                        want_rows: Some(false),
                        replication_index: None,
                    });
                }
                sql_over_http_requests.push(Stmt {
                    sql: Some("COMMIT".to_string()),
                    sql_id: None,
                    args: Vec::new(),
                    named_args: Vec::new(),
                    want_rows: Some(false),
                    replication_index: None,
                });
            }
        }
    }
    tracing::trace!("hrana request: {:?}", sql_over_http_requests);
    let replay_hrana_request = server_proto::PipelineReqBody {
        baton: None,
        requests: sql_over_http_requests
            .into_iter()
            .map(|stmt| StreamRequest::Execute(ExecuteStreamReq { stmt }))
            .collect(),
    };
    let _ = sql_execute_http(coro, client, replay_hrana_request).await?;
    tracing::info!("push_logical_changes: rows_changed={:?}", rows_changed);
    Ok((source_pull_gen, last_change_id.unwrap_or(0)))
}

pub async fn read_wal_salt<Ctx>(
    coro: &Coro<Ctx>,
    wal: &Arc<dyn turso_core::File>,
) -> Result<Option<Vec<u32>>> {
    #[allow(clippy::arc_with_non_send_sync)]
    let buffer = Arc::new(Buffer::new_temporary(WAL_HEADER));
    let c = Completion::new_read(buffer.clone(), |result| {
        let Ok((buffer, len)) = result else {
            return;
        };
        if (len as usize) < WAL_HEADER {
            buffer.as_mut_slice().fill(0);
        }
    });
    let c = wal.pread(0, c)?;
    while !c.is_completed() {
        coro.yield_(ProtocolCommand::IO).await?;
    }
    if buffer.as_mut_slice() == [0u8; WAL_HEADER] {
        return Ok(None);
    }
    let salt1 = u32::from_be_bytes(buffer.as_slice()[16..20].try_into().unwrap());
    let salt2 = u32::from_be_bytes(buffer.as_slice()[20..24].try_into().unwrap());
    Ok(Some(vec![salt1, salt2]))
}

pub async fn checkpoint_wal_file<Ctx>(
    coro: &Coro<Ctx>,
    conn: &Arc<turso_core::Connection>,
) -> Result<()> {
    let mut checkpoint_stmt = conn.prepare("PRAGMA wal_checkpoint(TRUNCATE)")?;
    loop {
        match checkpoint_stmt.step()?
        {
            turso_core::StepResult::IO => coro.yield_(ProtocolCommand::IO).await?,
            turso_core::StepResult::Done => break,
            turso_core::StepResult::Row => continue,
            r => {
                return Err(Error::DatabaseSyncEngineError(format!(
                    "unexpected checkpoint result: {r:?}"
                )))
            }
        }
    }
    Ok(())
}

pub async fn bootstrap_db_file<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    io: &Arc<dyn turso_core::IO>,
    main_db_path: &str,
    protocol: DatabaseSyncEngineProtocolVersion,
) -> Result<DatabasePullRevision> {
    match protocol {
        DatabaseSyncEngineProtocolVersion::Legacy => {
            bootstrap_db_file_legacy(coro, client, io, main_db_path).await
        }
        DatabaseSyncEngineProtocolVersion::V1 => {
            bootstrap_db_file_v1(coro, client, io, main_db_path).await
        }
    }
}

pub async fn bootstrap_db_file_v1<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    io: &Arc<dyn turso_core::IO>,
    main_db_path: &str,
) -> Result<DatabasePullRevision> {
    let mut bytes = BytesMut::new();
    let completion = client.http(
        "GET",
        "/pull-updates",
        None,
        &[
            ("content-type", "application/protobuf"),
            ("accept-encoding", "application/protobuf"),
        ],
    )?;
    let Some(header) =
        wait_proto_message::<Ctx, PullUpdatesRespProtoBody>(coro, &completion, &mut bytes).await?
    else {
        return Err(Error::DatabaseSyncEngineError(
            "no header returned in the pull-updates protobuf call".to_string(),
        ));
    };
    tracing::info!(
        "bootstrap_db_file(path={}): got header={:?}",
        main_db_path,
        header
    );
    let file = io.open_file(main_db_path, OpenFlags::Create, false)?;
    let c = Completion::new_trunc(move |result| {
        let Ok(rc) = result else {
            return;
        };
        assert!(rc as usize == 0);
    });
    let c = file.truncate(header.db_size as usize * PAGE_SIZE, c)?;
    while !c.is_completed() {
        coro.yield_(ProtocolCommand::IO).await?;
    }
    #[allow(clippy::arc_with_non_send_sync)]
    let buffer = Arc::new(Buffer::new_temporary(PAGE_SIZE));
    while let Some(page_data) =
        wait_proto_message::<Ctx, PageData>(coro, &completion, &mut bytes).await?
    {
        let offset = page_data.page_id as usize * PAGE_SIZE;
        let page = decode_page(&header, page_data)?;
        if page.len() != PAGE_SIZE {
            return Err(Error::DatabaseSyncEngineError(format!(
                "page has unexpected size: {} != {}",
                page.len(),
                PAGE_SIZE
            )));
        }
        buffer.as_mut_slice().copy_from_slice(&page);
        let c = Completion::new_write(move |result| {
            // todo(sivukhin): we need to error out in case of partial write
            let Ok(size) = result else {
                return;
            };
            assert!(size as usize == PAGE_SIZE);
        });
        let c = file.pwrite(offset, buffer.clone(), c)?;
        while !c.is_completed() {
            coro.yield_(ProtocolCommand::IO).await?;
        }
    }
    Ok(DatabasePullRevision::V1 {
        revision: header.server_revision,
    })
}

fn decode_page(header: &PullUpdatesRespProtoBody, page_data: PageData) -> Result<Vec<u8>> {
    if header.raw_encoding.is_some() && header.zstd_encoding.is_some() {
        return Err(Error::DatabaseSyncEngineError(
            "both of raw_encoding and zstd_encoding are set".to_string(),
        ));
    }
    if header.raw_encoding.is_none() && header.zstd_encoding.is_none() {
        return Err(Error::DatabaseSyncEngineError(
            "neither raw_encoding nor zstd_encoding is set".to_string(),
        ));
    }
    if header.raw_encoding.is_some() {
        return Ok(page_data.encoded_page.to_vec());
    }
    Err(Error::DatabaseSyncEngineError(
        "zstd encoding is not supported".to_string(),
    ))
}

pub async fn bootstrap_db_file_legacy<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    io: &Arc<dyn turso_core::IO>,
    main_db_path: &str,
) -> Result<DatabasePullRevision> {
    tracing::info!("bootstrap_db_file(path={})", main_db_path);
    let start_time = std::time::Instant::now();
    // cleanup all files left from a previous attempt to bootstrap
    // we shouldn't write any WAL files - but let's truncate them too for safety
    if let Some(file) = io.try_open(main_db_path)? {
        io.truncate(coro, file, 0).await?;
    }
    if let Some(file) = io.try_open(&format!("{main_db_path}-wal"))?
    {
        io.truncate(coro, file, 0).await?;
    }
    let file = io.create(main_db_path)?;
    let db_info = db_bootstrap(coro, client, file).await?;
    let elapsed = std::time::Instant::now().duration_since(start_time);
    tracing::info!(
        "bootstrap_db_file(path={}): finished: elapsed={:?}",
        main_db_path,
        elapsed
    );
    Ok(DatabasePullRevision::Legacy {
        generation: db_info.current_generation,
        synced_frame_no: None,
    })
}

pub async fn reset_wal_file<Ctx>(
    coro: &Coro<Ctx>,
    wal: Arc<dyn turso_core::File>,
    frames_count: u64,
) -> Result<()> {
    let wal_size = if frames_count == 0 {
        // truncate the WAL file completely so that this operation is safe to execute on an empty WAL during the initial bootstrap phase
        0
    } else {
        WAL_HEADER + WAL_FRAME_SIZE * (frames_count as usize)
    };
    tracing::debug!("reset db wal to the size of {} frames", frames_count);
    let c = Completion::new_trunc(move |result| {
        let Ok(rc) = result else {
            return;
        };
        assert!(rc as usize == 0);
    });
    let c = wal.truncate(wal_size, c)?;
    while !c.is_completed() {
        coro.yield_(ProtocolCommand::IO).await?;
    }
    Ok(())
}

async fn sql_execute_http<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    request: server_proto::PipelineReqBody,
) -> Result<Vec<StmtResult>> {
    let body = serde_json::to_vec(&request)?;
    let completion = client.http("POST", "/v2/pipeline", Some(body), &[])?;
    let status = wait_status(coro, &completion).await?;
    if status != http::StatusCode::OK {
        let error = format!("sql_execute_http: unexpected status code: {status}");
        return Err(Error::DatabaseSyncEngineError(error));
    }
    let response = wait_full_body(coro, &completion).await?;
    let response: server_proto::PipelineRespBody = serde_json::from_slice(&response)?;
    tracing::debug!("hrana response: {:?}", response);
    let mut results = Vec::new();
    for result in response.results {
        match result {
            server_proto::StreamResult::Error { error } => {
                return Err(Error::DatabaseSyncEngineError(format!(
                    "failed to execute sql: {error:?}"
                )))
            }
            server_proto::StreamResult::None => {
                return Err(Error::DatabaseSyncEngineError(
                    "unexpected None result".to_string(),
                ));
            }
            server_proto::StreamResult::Ok { response } => match response {
                server_proto::StreamResponse::Execute(execute) => {
                    results.push(execute.result);
                }
            },
        }
    }
    Ok(results)
}

async fn wal_pull_http<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    generation: u64,
    start_frame: u64,
    end_frame: u64,
) -> Result<WalHttpPullResult<C::DataCompletion>> {
    let completion = client.http(
        "GET",
        &format!("/sync/{generation}/{start_frame}/{end_frame}"),
        None,
        &[],
    )?;
    let status = wait_status(coro, &completion).await?;
    if status == http::StatusCode::BAD_REQUEST {
        let status_body = wait_full_body(coro, &completion).await?;
        let status: DbSyncStatus = serde_json::from_slice(&status_body)?;
        if status.status == "checkpoint_needed" {
            return Ok(WalHttpPullResult::NeedCheckpoint(status));
        } else {
            let error = format!("wal_pull: unexpected sync status: {status:?}");
            return Err(Error::DatabaseSyncEngineError(error));
        }
    }
    if status != http::StatusCode::OK {
        let error = format!("wal_pull: unexpected status code: {status}");
        return Err(Error::DatabaseSyncEngineError(error));
    }
    Ok(WalHttpPullResult::Frames(completion))
}

async fn wal_push_http<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    baton: Option<String>,
    generation: u64,
    start_frame: u64,
    end_frame: u64,
    frames: Vec<u8>,
) -> Result<DbSyncStatus> {
    let baton = baton
        .map(|baton| format!("/{baton}"))
        .unwrap_or("".to_string());
    let completion = client.http(
        "POST",
        &format!("/sync/{generation}/{start_frame}/{end_frame}{baton}"),
        Some(frames),
        &[],
    )?;
    let status = wait_status(coro, &completion).await?;
    let status_body = wait_full_body(coro, &completion).await?;
    if status != http::StatusCode::OK {
        let error =
            std::str::from_utf8(&status_body).ok().unwrap_or("");
        return Err(Error::DatabaseSyncEngineError(format!(
            "wal_push got unexpected status: {status} (error={error})"
        )));
    }
    Ok(serde_json::from_slice(&status_body)?)
}

async fn db_info_http<Ctx, C: ProtocolIO>(coro: &Coro<Ctx>, client: &C) -> Result<DbSyncInfo> {
    let completion = client.http("GET", "/info", None, &[])?;
    let status = wait_status(coro, &completion).await?;
    let status_body = wait_full_body(coro, &completion).await?;
    if status != http::StatusCode::OK {
        return Err(Error::DatabaseSyncEngineError(format!(
            "db_info got unexpected status: {status}"
        )));
    }
    Ok(serde_json::from_slice(&status_body)?)
}

async fn db_bootstrap_http<Ctx, C: ProtocolIO>(
    coro: &Coro<Ctx>,
    client: &C,
    generation: u64,
) -> Result<C::DataCompletion> {
    let completion = client.http("GET", &format!("/export/{generation}"), None, &[])?;
    let status = wait_status(coro, &completion).await?;
    if status != http::StatusCode::OK.as_u16() {
        return Err(Error::DatabaseSyncEngineError(format!(
            "db_bootstrap got unexpected status: {status}"
        )));
    }
    Ok(completion)
}

pub async fn wait_status<Ctx>(coro: &Coro<Ctx>, completion: &impl DataCompletion) -> Result<u16> {
    while completion.status()?.is_none() {
        coro.yield_(ProtocolCommand::IO).await?;
    }
    Ok(completion.status()?.unwrap())
}

/// Decode a LEB128-style varint (7 data bits per byte, MSB set on continuation bytes) from
/// the start of `buf`. Returns `Some((value, bytes_consumed))` when the varint is complete,
/// `None` when more bytes are needed, and an error if no terminating byte is found within 9 bytes.
#[inline(always)]
pub fn read_varint(buf: &[u8]) -> Result<Option<(usize, usize)>> {
    let mut v: u64 = 0;
    for i in 0..9 {
        match buf.get(i) {
            Some(c) => {
                v |= ((c & 0x7f) as u64) << (i * 7);
                if (c & 0x80) == 0 {
                    return Ok(Some((v as usize, i + 1)));
                }
            }
            None => return Ok(None),
        }
    }
    Err(Error::DatabaseSyncEngineError(format!(
        "invalid varint bytes: {:?}",
        &buf[0..=8]
    )))
}

pub async fn wait_proto_message<Ctx, T: Message + Default>(
    coro: &Coro<Ctx>,
    completion: &impl DataCompletion,
    bytes: &mut BytesMut,
) -> Result<Option<T>> {
    let start_time = std::time::Instant::now();
    loop {
        let length = read_varint(bytes)?;
        let not_enough_bytes = match length {
            None => true,
            Some((message_length, prefix_length)) => message_length + prefix_length > bytes.len(),
        };
        if not_enough_bytes {
            if let Some(poll) = completion.poll_data()? {
                bytes.extend_from_slice(poll.data());
            } else if !completion.is_done()? {
                coro.yield_(ProtocolCommand::IO).await?;
            } else if bytes.is_empty() {
                return Ok(None);
            } else {
                return Err(Error::DatabaseSyncEngineError(
                    "unexpected end of protobuf message".to_string(),
                ));
            }
            continue;
        }
        let (message_length, prefix_length) = length.unwrap();
        let message = T::decode_length_delimited(&**bytes).map_err(|e| {
            Error::DatabaseSyncEngineError(format!("unable to deserialize protobuf message: {e}"))
        })?;
        let _ = bytes.split_to(message_length + prefix_length);
        tracing::debug!(
            "wait_proto_message: elapsed={:?}",
            std::time::Instant::now().duration_since(start_time)
        );
        return Ok(Some(message));
    }
}

pub async fn wait_full_body<Ctx>(
    coro: &Coro<Ctx>,
    completion: &impl DataCompletion,
) -> Result<Vec<u8>> {
    let mut bytes = Vec::new();
    loop {
        while let Some(poll) = completion.poll_data()? {
            bytes.extend_from_slice(poll.data());
        }
        if completion.is_done()?
        {
            break;
        }
        coro.yield_(ProtocolCommand::IO).await?;
    }
    Ok(bytes)
}

#[cfg(test)]
mod tests {
    use std::cell::RefCell;

    use bytes::{Bytes, BytesMut};
    use prost::Message;

    use crate::{
        database_sync_operations::wait_proto_message,
        protocol_io::{DataCompletion, DataPollResult},
        server_proto::PageData,
        types::Coro,
        Result,
    };

    struct TestPollResult(Vec<u8>);

    impl DataPollResult for TestPollResult {
        fn data(&self) -> &[u8] {
            &self.0
        }
    }

    struct TestCompletion {
        data: RefCell<Bytes>,
        chunk: usize,
    }

    impl DataCompletion for TestCompletion {
        type DataPollResult = TestPollResult;

        fn status(&self) -> crate::Result<Option<u16>> {
            Ok(Some(200))
        }

        fn poll_data(&self) -> crate::Result<Option<Self::DataPollResult>> {
            let mut data = self.data.borrow_mut();
            let len = data.len();
            let chunk = data.split_to(len.min(self.chunk));
            if chunk.is_empty() {
                Ok(None)
            } else {
                Ok(Some(TestPollResult(chunk.to_vec())))
            }
        }

        fn is_done(&self) -> crate::Result<bool> {
            Ok(self.data.borrow().is_empty())
        }
    }

    #[test]
    pub fn wait_proto_message_test() {
        let mut data = Vec::new();
        for i in 0..1024 {
            let page = PageData {
                page_id: i as u64,
                encoded_page: vec![0u8; 16 * 1024].into(),
            };
            data.extend_from_slice(&page.encode_length_delimited_to_vec());
        }
        let completion = TestCompletion {
            data: RefCell::new(data.into()),
            chunk: 128,
        };
        let mut gen = genawaiter::sync::Gen::new({
            |coro| async move {
                let coro: Coro<()> = coro.into();
                let mut bytes = BytesMut::new();
                let mut count = 0;
                while wait_proto_message::<(), PageData>(&coro, &completion, &mut bytes)
                    .await?
                    .is_some()
                {
                    assert!(bytes.capacity() <= 16 * 1024 + 1024);
                    count += 1;
                }
                assert_eq!(count, 1024);
                Result::Ok(())
            }
        });
        loop {
            match gen.resume_with(Ok(())) {
                genawaiter::GeneratorState::Yielded(..) => {}
                genawaiter::GeneratorState::Complete(result) => break result.unwrap(),
            }
        }
    }
}
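
// A minimal sanity-check sketch for the LEB128-style `read_varint` decoder above; the module
// name and the specific byte sequences are illustrative additions, not part of the original
// test suite.
#[cfg(test)]
mod read_varint_tests {
    use super::read_varint;

    #[test]
    fn read_varint_decodes_short_and_multi_byte_values() {
        // single byte without continuation bit: value 5, one byte consumed
        assert_eq!(read_varint(&[0x05]).unwrap(), Some((5, 1)));
        // two-byte value: 0x81 0x01 => 1 + (1 << 7) = 129, two bytes consumed
        assert_eq!(read_varint(&[0x81, 0x01]).unwrap(), Some((129, 2)));
        // continuation bit set but no next byte yet: decoder asks for more data
        assert_eq!(read_varint(&[0x80]).unwrap(), None);
        // nine continuation bytes without a terminator are rejected
        assert!(read_varint(&[0xff; 9]).is_err());
    }
}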