diff --git a/sync/engine/src/database_sync_operations.rs b/sync/engine/src/database_sync_operations.rs index 5197e04e0..c8235df33 100644 --- a/sync/engine/src/database_sync_operations.rs +++ b/sync/engine/src/database_sync_operations.rs @@ -18,8 +18,9 @@ use crate::{ io_operations::IoOperations, protocol_io::{DataCompletion, DataPollResult, ProtocolIO}, server_proto::{ - self, ExecuteStreamReq, PageData, PageUpdatesEncodingReq, PullUpdatesReqProtoBody, - PullUpdatesRespProtoBody, Stmt, StmtResult, StreamRequest, + self, Batch, BatchCond, BatchStep, BatchStreamReq, ExecuteStreamReq, PageData, + PageUpdatesEncodingReq, PullUpdatesReqProtoBody, PullUpdatesRespProtoBody, Stmt, + StmtResult, StreamRequest, }, types::{ Coro, DatabasePullRevision, DatabaseRowTransformResult, DatabaseSyncEngineProtocolVersion, @@ -713,23 +714,32 @@ pub async fn push_logical_changes( ignore_schema_changes: false, ..Default::default() }; + let step = |query, args| BatchStep { + stmt: Stmt { + sql: Some(query), + sql_id: None, + args, + named_args: Vec::new(), + want_rows: Some(false), + replication_index: None, + }, + condition: Some(BatchCond::Not { + cond: Box::new(BatchCond::IsAutocommit {}), + }), + }; let mut sql_over_http_requests = vec![ - Stmt { - sql: Some("BEGIN IMMEDIATE".to_string()), - sql_id: None, - args: Vec::new(), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }, - Stmt { - sql: Some(TURSO_SYNC_CREATE_TABLE.to_string()), - sql_id: None, - args: Vec::new(), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, + BatchStep { + stmt: Stmt { + sql: Some("BEGIN IMMEDIATE".to_string()), + sql_id: None, + args: Vec::new(), + named_args: Vec::new(), + want_rows: Some(false), + replication_index: None, + }, + condition: None, }, + step(TURSO_SYNC_CREATE_TABLE.to_string(), Vec::new()), ]; let mut rows_changed = 0; let mut changes = source.iterate_changes(iterate_opts)?; @@ -797,14 +807,9 @@ pub async fn push_logical_changes( DatabaseTapeOperation::Commit => { panic!("Commit operation must not be emited at this stage") } - DatabaseTapeOperation::StmtReplay(replay) => sql_over_http_requests.push(Stmt { - sql: Some(replay.sql), - sql_id: None, - args: convert_to_args(replay.values), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }), + DatabaseTapeOperation::StmtReplay(replay) => { + sql_over_http_requests.push(step(replay.sql, convert_to_args(replay.values))) + } DatabaseTapeOperation::RowChange(change) => { let replay_info = generator.replay_info(coro, &change).await?; match change.change { @@ -816,14 +821,8 @@ pub async fn push_logical_changes( before, None, ); - sql_over_http_requests.push(Stmt { - sql: Some(replay_info.query.clone()), - sql_id: None, - args: convert_to_args(values), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }) + sql_over_http_requests + .push(step(replay_info.query.clone(), convert_to_args(values))) } DatabaseTapeRowChangeType::Insert { after } => { let values = generator.replay_values( @@ -833,14 +832,8 @@ pub async fn push_logical_changes( after, None, ); - sql_over_http_requests.push(Stmt { - sql: Some(replay_info.query.clone()), - sql_id: None, - args: convert_to_args(values), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }) + sql_over_http_requests + .push(step(replay_info.query.clone(), convert_to_args(values))); } DatabaseTapeRowChangeType::Update { after, @@ -854,14 +847,8 @@ pub async fn push_logical_changes( after, Some(updates), ); - sql_over_http_requests.push(Stmt { - sql: Some(replay_info.query.clone()), - sql_id: None, - args: convert_to_args(values), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }) + sql_over_http_requests + .push(step(replay_info.query.clone(), convert_to_args(values))); } DatabaseTapeRowChangeType::Update { after, @@ -875,14 +862,8 @@ pub async fn push_logical_changes( after, None, ); - sql_over_http_requests.push(Stmt { - sql: Some(replay_info.query.clone()), - sql_id: None, - args: convert_to_args(values), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }); + sql_over_http_requests + .push(step(replay_info.query.clone(), convert_to_args(values))); } } } @@ -894,10 +875,9 @@ pub async fn push_logical_changes( // update turso_sync_last_change_id table with new value before commit let next_change_id = last_change_id.unwrap_or(0); tracing::info!("push_logical_changes: client_id={client_id}, set pull_gen={source_pull_gen}, change_id={next_change_id}, rows_changed={rows_changed}"); - sql_over_http_requests.push(Stmt { - sql: Some(TURSO_SYNC_UPSERT_LAST_CHANGE_ID.to_string()), - sql_id: None, - args: vec![ + sql_over_http_requests.push(step( + TURSO_SYNC_UPSERT_LAST_CHANGE_ID.to_string(), + vec![ server_proto::Value::Text { value: client_id.to_string(), }, @@ -908,27 +888,20 @@ pub async fn push_logical_changes( value: next_change_id, }, ], - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }); + )); } - sql_over_http_requests.push(Stmt { - sql: Some("COMMIT".to_string()), - sql_id: None, - args: Vec::new(), - named_args: Vec::new(), - want_rows: Some(false), - replication_index: None, - }); + sql_over_http_requests.push(step("COMMIT".to_string(), Vec::new())); tracing::trace!("hrana request: {:?}", sql_over_http_requests); let replay_hrana_request = server_proto::PipelineReqBody { baton: None, - requests: sql_over_http_requests - .into_iter() - .map(|stmt| StreamRequest::Execute(ExecuteStreamReq { stmt })) - .collect(), + requests: vec![StreamRequest::Batch(BatchStreamReq { + batch: Batch { + steps: sql_over_http_requests.into(), + replication_index: None, + }, + })] + .into(), }; let _ = sql_execute_http(coro, client, replay_hrana_request).await?; @@ -1206,6 +1179,20 @@ async fn sql_execute_http( server_proto::StreamResponse::Execute(execute) => { results.push(execute.result); } + server_proto::StreamResponse::Batch(batch) => { + for error in batch.result.step_errors { + if let Some(error) = error { + return Err(Error::DatabaseSyncEngineError(format!( + "failed to execute sql: {error:?}" + ))); + } + } + for result in batch.result.step_results { + if let Some(result) = result { + results.push(result); + } + } + } }, } } diff --git a/sync/engine/src/server_proto.rs b/sync/engine/src/server_proto.rs index 19e72082c..bca505d68 100644 --- a/sync/engine/src/server_proto.rs +++ b/sync/engine/src/server_proto.rs @@ -82,6 +82,8 @@ pub enum StreamRequest { None, /// See [`ExecuteStreamReq`] Execute(ExecuteStreamReq), + /// See [`BatchStreamReq`] + Batch(BatchStreamReq), } #[derive(Serialize, Deserialize, Default, Debug, PartialEq)] @@ -101,6 +103,66 @@ pub enum StreamResult { #[serde(tag = "type", rename_all = "snake_case")] pub enum StreamResponse { Execute(ExecuteStreamResp), + Batch(BatchStreamResp), +} + +#[derive(Serialize, Deserialize, Debug)] +/// A request to execute a batch of SQL statements that may each have a [`BatchCond`] that must be satisfied for the statement to be executed. +pub struct BatchStreamReq { + pub batch: Batch, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +/// A response to a [`BatchStreamReq`]. +pub struct BatchStreamResp { + pub result: BatchResult, +} + +#[derive(Clone, Deserialize, Serialize, Debug, Default, PartialEq)] +pub struct BatchResult { + pub step_results: Vec>, + pub step_errors: Vec>, + #[serde(default, with = "option_u64_as_str")] + pub replication_index: Option, +} + +#[derive(Clone, Deserialize, Serialize, Debug)] +pub struct Batch { + pub steps: VecDeque, + #[serde(default, with = "option_u64_as_str")] + pub replication_index: Option, +} + +#[derive(Clone, Deserialize, Serialize, Debug)] +pub struct BatchStep { + #[serde(default)] + pub condition: Option, + pub stmt: Stmt, +} + +#[derive(Clone, Deserialize, Serialize, Debug, Default)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BatchCond { + #[serde(skip_deserializing)] + #[default] + None, + Ok { + step: u32, + }, + Error { + step: u32, + }, + Not { + cond: Box, + }, + And(BatchCondList), + Or(BatchCondList), + IsAutocommit {}, +} + +#[derive(Clone, Deserialize, Serialize, Debug)] +pub struct BatchCondList { + pub conds: Vec, } #[derive(Serialize, Deserialize, Debug, PartialEq)]