diff --git a/packages/turso-sync/src/sync_server/test.rs b/packages/turso-sync/src/sync_server/test.rs new file mode 100644 index 000000000..c0c13a9ba --- /dev/null +++ b/packages/turso-sync/src/sync_server/test.rs @@ -0,0 +1,283 @@ +use std::{collections::HashMap, path::Path, sync::Arc}; + +use tokio::sync::Mutex; +use turso::{IntoParams, Value}; + +use crate::{ + errors::Error, + sync_server::{DbSyncInfo, DbSyncStatus, Stream, SyncServer}, + test_context::TestContext, + Result, +}; + +struct Generation { + snapshot: Vec, + frames: Vec>, +} + +#[derive(Clone)] +struct SyncSession { + baton: String, + conn: turso::Connection, + in_txn: bool, +} + +struct TestSyncServerState { + generation: usize, + generations: HashMap, + sessions: HashMap, +} + +#[derive(Debug, Clone)] +pub struct TestSyncServerOpts { + pub pull_batch_size: usize, +} + +#[derive(Clone)] +pub struct TestSyncServer { + ctx: Arc, + db: turso::Database, + opts: Arc, + state: Arc>, +} + +pub struct TestStream { + ctx: Arc, + data: Vec, + position: usize, +} + +impl TestStream { + pub fn new(ctx: Arc, data: Vec) -> Self { + Self { + ctx, + data, + position: 0, + } + } +} + +impl Stream for TestStream { + async fn read_chunk(&mut self) -> Result> { + self.ctx + .faulty_call(if self.position == 0 { + "read_chunk_first" + } else { + "read_chunk_next" + }) + .await?; + let size = (self.data.len() - self.position).min(FRAME_SIZE); + if size == 0 { + Ok(None) + } else { + let chunk = &self.data[self.position..self.position + size]; + self.position += size; + Ok(Some(hyper::body::Bytes::copy_from_slice(chunk))) + } + } +} + +const PAGE_SIZE: usize = 4096; +const FRAME_SIZE: usize = 24 + PAGE_SIZE; + +impl SyncServer for TestSyncServer { + type Stream = TestStream; + async fn db_info(&self) -> Result { + self.ctx.faulty_call("db_info").await?; + + let state = self.state.lock().await; + Ok(DbSyncInfo { + current_generation: state.generation, + }) + } + + async fn db_export(&self, generation_id: usize) -> Result { + self.ctx.faulty_call("db_export").await?; + + let state = self.state.lock().await; + let Some(generation) = state.generations.get(&generation_id) else { + return Err(Error::DatabaseSyncError(format!("generation not found"))); + }; + Ok(TestStream::new( + self.ctx.clone(), + generation.snapshot.clone(), + )) + } + + async fn wal_pull(&self, generation_id: usize, start_frame: usize) -> Result { + tracing::debug!("wal_pull: {}/{}", generation_id, start_frame); + self.ctx.faulty_call("wal_pull").await?; + + let state = self.state.lock().await; + let Some(generation) = state.generations.get(&generation_id) else { + return Err(Error::DatabaseSyncError(format!("generation not found"))); + }; + let mut data = Vec::new(); + for frame_no in start_frame..start_frame + self.opts.pull_batch_size { + let frame_idx = frame_no - 1; + let Some(frame) = generation.frames.get(frame_idx) else { + break; + }; + data.extend_from_slice(frame); + } + if data.len() == 0 { + let last_generation = state.generations.get(&state.generation).unwrap(); + return Err(Error::PullNeedCheckpoint(DbSyncStatus { + baton: None, + status: "checkpoint_needed".to_string(), + generation: state.generation, + max_frame_no: last_generation.frames.len(), + })); + } + Ok(TestStream::new(self.ctx.clone(), data)) + } + + async fn wal_push( + &self, + mut baton: Option, + generation_id: usize, + start_frame: usize, + end_frame: usize, + frames: Vec, + ) -> Result { + tracing::debug!( + "wal_push: {}/{}/{}/{:?}", + generation_id, + start_frame, + end_frame, + baton + ); + self.ctx.faulty_call("wal_push").await?; + + let mut session = { + let mut state = self.state.lock().await; + if state.generation != generation_id { + return Err(Error::DatabaseSyncError(format!("generation id mismatch"))); + } + let baton_str = baton.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let session = match state.sessions.get(&baton_str) { + Some(session) => session.clone(), + None => { + let session = SyncSession { + baton: baton_str.clone(), + conn: self.db.connect()?, + in_txn: false, + }; + state.sessions.insert(baton_str.clone(), session.clone()); + session + } + }; + baton = Some(baton_str.clone()); + session + }; + + let mut offset = 0; + for frame_no in start_frame..end_frame { + if offset + FRAME_SIZE > frames.len() { + return Err(Error::DatabaseSyncError(format!( + "unexpected length of frames data" + ))); + } + if !session.in_txn { + session.conn.wal_insert_begin()?; + session.in_txn = true; + } + let frame = &frames[offset..offset + FRAME_SIZE]; + match session.conn.wal_insert_frame(frame_no as u32, frame) { + Ok(info) => { + if info.is_commit { + if session.in_txn { + session.conn.wal_insert_end()?; + session.in_txn = false; + } + self.sync_frames_from_conn(&session.conn).await?; + } + } + Err(turso::Error::WalOperationError(err)) if err.contains("Conflict") => { + session.conn.wal_insert_end()?; + return Err(Error::PushConflict); + } + Err(err) => { + session.conn.wal_insert_end()?; + return Err(err.into()); + } + } + offset += FRAME_SIZE; + } + let mut state = self.state.lock().await; + state + .sessions + .insert(baton.clone().unwrap(), session.clone()); + Ok(DbSyncStatus { + baton: Some(session.baton.clone()), + status: "ok".into(), + generation: state.generation, + max_frame_no: session.conn.wal_frame_count()? as usize, + }) + } +} + +// empty DB with single 4096-byte page and WAL mode (PRAGMA journal_mode=WAL) +// see test test_empty_wal_mode_db_content which validates asset content +pub const EMPTY_WAL_MODE_DB: &[u8] = include_bytes!("empty_wal_mode.db"); + +pub async fn convert_rows(rows: &mut turso::Rows) -> Result>> { + let mut rows_values = vec![]; + while let Some(row) = rows.next().await? { + let mut row_values = vec![]; + for i in 0..row.column_count() { + row_values.push(row.get_value(i)?); + } + rows_values.push(row_values); + } + Ok(rows_values) +} + +impl TestSyncServer { + pub async fn new(ctx: Arc, path: &Path, opts: TestSyncServerOpts) -> Result { + let mut generations = HashMap::new(); + generations.insert( + 1, + Generation { + snapshot: EMPTY_WAL_MODE_DB.to_vec(), + frames: Vec::new(), + }, + ); + Ok(Self { + ctx, + db: turso::Builder::new_local(path.to_str().unwrap()) + .build() + .await?, + opts: Arc::new(opts), + state: Arc::new(Mutex::new(TestSyncServerState { + generation: 1, + generations: generations, + sessions: HashMap::new(), + })), + }) + } + pub fn db(&self) -> turso::Database { + self.db.clone() + } + pub async fn execute(&self, sql: &str, params: impl IntoParams) -> Result<()> { + let conn = self.db.connect()?; + conn.execute(sql, params).await?; + self.sync_frames_from_conn(&conn).await?; + Ok(()) + } + async fn sync_frames_from_conn(&self, conn: &turso::Connection) -> Result<()> { + let mut state = self.state.lock().await; + let generation = state.generation; + let generation = state.generations.get_mut(&generation).unwrap(); + let last_frame = generation.frames.len() + 1; + let mut frame = [0u8; FRAME_SIZE]; + let wal_frame_count = conn.wal_frame_count()?; + tracing::debug!("conn frames count: {}", wal_frame_count); + for frame_no in last_frame..=wal_frame_count as usize { + conn.wal_get_frame(frame_no as u32, &mut frame)?; + tracing::debug!("push local frame {}", frame_no); + generation.frames.push(frame.to_vec()); + } + Ok(()) + } +}