diff --git a/.cargo/config.toml b/.cargo/config.toml index 2586c8988..9af43e9ac 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,8 @@ [env] LIBSQLITE3_FLAGS = "-DSQLITE_ENABLE_MATH_FUNCTIONS" # necessary for rusqlite dependency in order to bundle SQLite with math functions included + +[build] +# turso-sync package uses tokio_unstable to seed LocalRuntime and make it deterministic +# unfortunately, cargo commands invoked from workspace root didn't capture config.toml from dependent crate +# so, we set this cfg globally for workspace (see relevant issue build build-target: https://github.com/rust-lang/cargo/issues/7004) +rustflags = ["--cfg=tokio_unstable"] \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 5c9feb314..176f9b441 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -225,6 +225,29 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-lc-rs" +version = "1.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c953fe1ba023e6b7730c0d4b031d06f267f23a46167dcbd40316644b10a17ba" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -264,6 +287,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools 0.10.5", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + 
"regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.100", + "which", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -375,6 +421,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -435,6 +490,17 @@ dependencies = [ "half", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.5.32" @@ -496,6 +562,15 @@ dependencies = [ "error-code", ] +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "colorchoice" version = "1.0.3" @@ -563,6 +638,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -577,7 +662,7 @@ dependencies = [ "assert_cmd", "env_logger 0.10.2", "log", - "rand 0.9.0", + "rand 0.9.2", "rand_chacha 0.9.0", "rexpect", "rusqlite", @@ -952,6 +1037,12 @@ version = "0.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7454e41ff9012c00d53cf7f475c5e3afa3b91b7c90568495495e8d9bf47a1055" +[[package]] +name = 
"dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dyn-clone" version = "1.0.19" @@ -1174,6 +1265,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "fsevent-sys" version = "4.1.0" @@ -1405,12 +1502,110 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + [[package]] name = "humantime" version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + 
"bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "log", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "hyper", + "libc", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + [[package]] name = "iana-time-zone" version = "0.1.62" @@ -1819,6 +2014,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" version = "0.2.172" @@ -2175,6 +2376,12 @@ dependencies = [ "libmimalloc-sys", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.5" @@ -2207,7 +2414,7 @@ dependencies = [ "napi-build", "napi-sys", "nohash-hasher", - "rustc-hash", + "rustc-hash 2.1.1", ] [[package]] @@ -2301,6 +2508,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum 
= "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "notify" version = "8.0.0" @@ -2424,6 +2641,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + [[package]] name = "option-ext" version = "0.2.0" @@ -2642,6 +2865,16 @@ dependencies = [ "termtree", ] +[[package]] +name = "prettyplease" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6837b9e10d61f45f987d50808f83d1ee3d206c66acf650c3e4ae2e1f6ddedf55" +dependencies = [ + "proc-macro2", + "syn 2.0.100", +] + [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -2824,13 +3057,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", - "zerocopy 0.8.26", ] [[package]] @@ -3057,6 +3289,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = 
"rustc-hash" version = "2.1.1" @@ -3108,6 +3346,54 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "rustls" +version = "0.23.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ebcbd2f03de0fc1122ad9bb24b127a5a6cd51d72604a3f3c50ac459762b6cc" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pki-types" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.20" @@ -3163,6 +3449,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "schemars" version = "0.8.22" @@ -3194,6 +3489,29 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags 2.9.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.26" @@ -3293,12 +3611,12 @@ checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "socket2" -version = "0.6.0" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -3356,6 +3674,12 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "supports-color" version = "3.0.2" @@ -3642,9 +3966,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.47.0" +version = "1.46.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43864ed400b6043a4757a25c7a64a8efde741aed79a056a2fb348a406701bb35" +checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" dependencies = [ "backtrace", "bytes", @@ -3657,7 +3981,7 @@ dependencies = [ "slab", "socket2", "tokio-macros", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -3671,6 +3995,16 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tokio-rustls" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "toml" version = "0.8.22" @@ -3713,6 +4047,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.41" @@ -3786,6 +4126,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "turso" version = "0.1.3" @@ -3810,6 +4156,18 @@ name = "turso-sync" version = "0.1.3" dependencies = [ "ctor", + "futures", + "http", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "paste", + "rand 0.9.2", + "rand_chacha 0.9.0", + "rustls", + "serde", + "serde_json", "tempfile", "thiserror 2.0.12", "tokio", @@ -3817,6 +4175,7 @@ dependencies = [ "tracing-subscriber", "turso", "turso_core", + "uuid", ] [[package]] @@ -4156,6 +4515,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -4252,6 +4620,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + 
[[package]] name = "winapi" version = "0.3.9" @@ -4636,6 +5016,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + [[package]] name = "zerovec" version = "0.10.4" diff --git a/packages/turso-sync/Cargo.toml b/packages/turso-sync/Cargo.toml index ef5ef40fa..c7079b94b 100644 --- a/packages/turso-sync/Cargo.toml +++ b/packages/turso-sync/Cargo.toml @@ -11,9 +11,23 @@ turso_core = { workspace = true } turso = { workspace = true } thiserror = "2.0.12" tracing = "0.1.41" +hyper = { version = "1.6.0", features = ["client", "http1"] } +serde_json.workspace = true +http-body-util = "0.1.3" +http = "1.3.1" +hyper-util = { version = "0.1.16", features = ["tokio", "http1", "client"] } +serde = { workspace = true, features = ["derive"] } +tokio = { version = "1.46.1", features = ["fs", "io-util"] } +hyper-rustls = "0.27.7" +rustls = "0.23.31" [dev-dependencies] ctor = "0.4.2" tempfile = "3.20.0" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } -tokio = { version = "1.46.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.46.1", features = ["macros", "rt-multi-thread", "test-util"] } +uuid = "1.17.0" +rand = "0.9.2" +rand_chacha = "0.9.0" +futures = "0.3.31" +paste = "1.0.15" diff --git a/packages/turso-sync/examples/example_sync.rs b/packages/turso-sync/examples/example_sync.rs new file mode 100644 index 000000000..1200caeb4 --- /dev/null +++ b/packages/turso-sync/examples/example_sync.rs @@ -0,0 +1,70 @@ +use std::io::{self, Write}; + +use tracing_subscriber::EnvFilter; +use turso_sync::database::Builder; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + let sync_url = std::env::var("TURSO_SYNC_URL").unwrap(); + let auth_token = std::env::var("TURSO_AUTH_TOKEN").ok(); + let 
local_path = std::env::var("TURSO_LOCAL_PATH").unwrap(); + let mut db = Builder::new_synced(&local_path, &sync_url, auth_token) + .build() + .await + .unwrap(); + + loop { + print!("> "); + io::stdout().flush().unwrap(); + + let mut input = String::new(); + let bytes_read = io::stdin().read_line(&mut input).unwrap(); + + if bytes_read == 0 { + break; + } + + let trimmed = input.trim(); + match trimmed { + ".exit" | ".quit" => break, + ".pull" => { + db.pull().await.unwrap(); + continue; + } + ".push" => { + db.push().await.unwrap(); + continue; + } + ".sync" => { + db.sync().await.unwrap(); + continue; + } + _ => {} + } + let mut rows = db.query(&input, ()).await.unwrap(); + while let Some(row) = rows.next().await.unwrap() { + let mut values = vec![]; + for i in 0..row.column_count() { + let value = row.get_value(i).unwrap(); + match value { + turso::Value::Null => values.push("NULL".to_string()), + turso::Value::Integer(x) => values.push(format!("{x}")), + turso::Value::Real(x) => values.push(format!("{x}")), + turso::Value::Text(x) => values.push(format!("'{x}'")), + turso::Value::Blob(x) => values.push(format!( + "x'{}'", + x.iter() + .map(|x| format!("{x:02x}")) + .collect::>() + .join(""), + )), + } + } + println!("{}", &values.join(" ")); + io::stdout().flush().unwrap(); + } + } +} diff --git a/packages/turso-sync/src/database.rs b/packages/turso-sync/src/database.rs new file mode 100644 index 000000000..2bc64729d --- /dev/null +++ b/packages/turso-sync/src/database.rs @@ -0,0 +1,98 @@ +use std::path::PathBuf; + +use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder}; +use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; + +use crate::{ + database_inner::{DatabaseInner, Rows}, + errors::Error, + filesystem::tokio::TokioFilesystem, + sync_server::turso::{TursoSyncServer, TursoSyncServerOpts}, + Result, +}; + +/// [Database] expose public interface for synced database from [DatabaseInner] private implementation +/// 
+/// This layer also serves a purpose of "gluing" together all component for real use, +/// because [DatabaseInner] abstracts things away in order to simplify testing +pub struct Database(DatabaseInner); + +pub struct Builder { + path: String, + sync_url: String, + auth_token: Option, + encryption_key: Option, + connector: Option>, +} + +impl Builder { + pub fn new_synced(path: &str, sync_url: &str, auth_token: Option) -> Self { + Self { + path: path.to_string(), + sync_url: sync_url.to_string(), + auth_token, + encryption_key: None, + connector: None, + } + } + pub fn with_encryption_key(self, encryption_key: &str) -> Self { + Self { + encryption_key: Some(encryption_key.to_string()), + ..self + } + } + pub fn with_connector(self, connector: HttpsConnector) -> Self { + Self { + connector: Some(connector), + ..self + } + } + pub async fn build(self) -> Result { + let path = PathBuf::from(self.path); + let connector = self.connector.map(Ok).unwrap_or_else(default_connector)?; + let executor = TokioExecutor::new(); + let client = hyper_util::client::legacy::Builder::new(executor).build(connector); + let sync_server = TursoSyncServer::new( + client, + TursoSyncServerOpts { + sync_url: self.sync_url, + auth_token: self.auth_token, + encryption_key: self.encryption_key, + pull_batch_size: None, + }, + )?; + let filesystem = TokioFilesystem(); + let inner = DatabaseInner::new(filesystem, sync_server, &path).await?; + Ok(Database(inner)) + } +} + +impl Database { + pub async fn sync(&mut self) -> Result<()> { + self.0.sync().await + } + pub async fn pull(&mut self) -> Result<()> { + self.0.pull().await + } + pub async fn push(&mut self) -> Result<()> { + self.0.push().await + } + pub async fn execute(&self, sql: &str, params: impl turso::IntoParams) -> Result { + self.0.execute(sql, params).await + } + pub async fn query(&self, sql: &str, params: impl turso::IntoParams) -> Result { + self.0.query(sql, params).await + } +} + +pub fn default_connector() -> Result> { + let 
tls_config = rustls::ClientConfig::builder() + .with_native_roots() + .map_err(|e| Error::DatabaseSyncError(format!("unable to configure CA roots: {e}")))? + .with_no_client_auth(); + Ok(HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_or_http() + .enable_http1() + .build()) +} diff --git a/packages/turso-sync/src/database_inner.rs b/packages/turso-sync/src/database_inner.rs new file mode 100644 index 000000000..77c0297d9 --- /dev/null +++ b/packages/turso-sync/src/database_inner.rs @@ -0,0 +1,1117 @@ +use std::{ + io::ErrorKind, + path::{Path, PathBuf}, + sync::Arc, +}; + +use tokio::sync::{OwnedRwLockReadGuard, RwLock}; + +use crate::{ + database_tape::{ + DatabaseChangesIteratorMode, DatabaseChangesIteratorOpts, DatabaseTape, DatabaseTapeOpts, + }, + errors::Error, + filesystem::Filesystem, + metadata::{ActiveDatabase, DatabaseMetadata}, + sync_server::{Stream, SyncServer}, + types::DatabaseTapeOperation, + wal_session::WalSession, + Result, +}; + +pub struct DatabaseInner { + filesystem: F, + sync_server: S, + draft_path: PathBuf, + synced_path: PathBuf, + meta_path: PathBuf, + meta: Option, + database: Arc>, + // we remember information if Synced DB is dirty - which will make Database to reset it in case of any sync attempt + // this bit is set to false when we properly reset Synced DB + // this bit is set to true when we transfer changes from Draft to Synced or on initialization + synced_is_dirty: bool, +} + +struct ActiveDatabaseContainer { + db: Option, + active_type: ActiveDatabase, +} + +impl ActiveDatabaseContainer { + pub fn active(&self) -> &DatabaseTape { + self.db.as_ref().unwrap() + } +} + +pub struct Rows { + _guard: OwnedRwLockReadGuard, + rows: turso::Rows, +} + +impl std::ops::Deref for Rows { + type Target = turso::Rows; + + fn deref(&self) -> &Self::Target { + &self.rows + } +} + +impl std::ops::DerefMut for Rows { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.rows + } +} + +const PAGE_SIZE: usize = 4096; 
+const WAL_HEADER: usize = 32; +const FRAME_SIZE: usize = 24 + PAGE_SIZE; + +impl DatabaseInner { + pub async fn new(filesystem: F, sync_server: S, path: &Path) -> Result { + let path_str = path + .to_str() + .ok_or_else(|| Error::DatabaseSyncError(format!("invalid path: {path:?}")))?; + let draft_path = PathBuf::from(format!("{path_str}-draft")); + let synced_path = PathBuf::from(format!("{path_str}-synced")); + let meta_path = PathBuf::from(format!("{path_str}-info")); + let database_container = Arc::new(RwLock::new(ActiveDatabaseContainer { + db: None, + active_type: ActiveDatabase::Draft, + })); + let mut db = Self { + sync_server, + filesystem, + draft_path, + synced_path, + meta_path, + meta: None, + database: database_container, + synced_is_dirty: true, + }; + db.init().await?; + Ok(db) + } + + pub async fn execute(&self, sql: &str, params: impl turso::IntoParams) -> Result { + let database = self.database.read().await; + let active = database.active(); + let conn = active.connect().await?; + let result = conn.execute(sql, params).await?; + Ok(result) + } + + pub async fn query(&self, sql: &str, params: impl turso::IntoParams) -> Result { + let database = self.database.clone().read_owned().await; + let active = database.active(); + let conn = active.connect().await?; + let rows = conn.query(sql, params).await?; + Ok(Rows { + _guard: database, + rows, + }) + } + + /// Sync any new changes from remote DB and apply them locally + /// This method will **not** send local changed to the remote + /// This method will block writes for the period of sync + pub async fn pull(&mut self) -> Result<()> { + tracing::debug!("sync_from_remote"); + self.cleanup_synced().await?; + + self.pull_synced_from_remote().await?; + // we will copy Synced WAL to the Draft WAL later without pushing it to the remote + // so, we pass 'capture: true' as we need to preserve all changes for future push of WAL + let _ = self.transfer_draft_to_synced(true).await?; + assert!( + 
self.synced_is_dirty, + "synced_is_dirty must be set after transfer_draft_to_synced" + ); + + // switch requests to Synced DB and update metadata + // because reading Draft while we are transferring data from Synced is not allowed + self.switch_active(ActiveDatabase::Synced, self.open_synced(false).await?) + .await?; + + // as we transferred row changes from Draft to Synced, all changes will be re-written (with their IDs starts from 1) + // and we must updated synced_change_id + self.meta = Some(self.write_meta(|meta| meta.synced_change_id = None).await?); + + self.transfer_synced_to_draft().await.inspect_err(|e| { + tracing::error!("transfer_synced_to_draft failed, writes are blocked for the DB: {e}",) + })?; + + // switch requests back to Draft DB + self.switch_active(ActiveDatabase::Draft, self.open_draft().await?) + .await?; + + // Synced DB now has extra WAL frames from [transfer_draft_to_synced] call, so we need to reset them + self.reset_synced().await?; + assert!( + !self.synced_is_dirty, + "synced_is_dirty must not be set after reset_synced" + ); + Ok(()) + } + + /// Sync local changes to remote DB + /// This method will **not** pull remote changes to the local DB + /// This method will **not** block writes for the period of sync + pub async fn push(&mut self) -> Result<()> { + tracing::debug!("sync to remote"); + self.cleanup_synced().await?; + + self.pull_synced_from_remote().await?; + + let change_id = self.transfer_draft_to_synced(false).await?; + // update transferred_change_id field because after we will start pushing frames - we must be able to resume this operation + // otherwise, we will encounter conflicts because some frames will be pushed while we will think that they are not + self.meta = Some( + self.write_meta(|meta| meta.transferred_change_id = change_id) + .await?, + ); + self.push_synced_to_remote(change_id).await?; + Ok(()) + } + + /// Sync local changes to remote DB and bring new changes from remote to local + /// This method will block 
writes for the period of sync + pub async fn sync(&mut self) -> Result<()> { + // todo(sivukhin): this is bit suboptimal as both sync_to_remote and sync_from_remote will call pull_synced_from_remote + // but for now - keep it simple + self.push().await?; + self.pull().await?; + Ok(()) + } + + async fn init(&mut self) -> Result<()> { + tracing::debug!("initialize synced database instance"); + + match DatabaseMetadata::read_from(&self.filesystem, &self.meta_path).await? { + Some(meta) => self.meta = Some(meta), + None => { + let meta = self.bootstrap_db_files().await?; + + tracing::debug!("write meta after successful bootstrap"); + meta.write_to(&self.filesystem, &self.meta_path).await?; + self.meta = Some(meta); + } + }; + + let draft_exists = self.filesystem.exists_file(&self.draft_path).await?; + let synced_exists = self.filesystem.exists_file(&self.synced_path).await?; + if !draft_exists || !synced_exists { + return Err(Error::DatabaseSyncError( + "Draft or Synced files doesn't exists, but metadata is".to_string(), + )); + } + + // Synced db is active - we need to finish transfer from Synced to Draft then + if self.meta().active_db == ActiveDatabase::Synced { + self.transfer_synced_to_draft().await?; + } + + // sync WAL from the remote + self.pull().await?; + + assert!( + self.meta().active_db == ActiveDatabase::Draft, + "active_db must be Draft after init" + ); + let db = self.open_draft().await?; + self.database.write().await.db = Some(db); + Ok(()) + } + + async fn open_synced(&self, capture: bool) -> Result { + let clean_path_str = self.synced_path.to_str().unwrap(); + let clean = turso::Builder::new_local(clean_path_str).build().await?; + let opts = DatabaseTapeOpts { + cdc_table: None, + cdc_mode: Some(if capture { "after" } else { "off" }.to_string()), + }; + tracing::debug!("initialize clean database connection"); + Ok(DatabaseTape::new_with_opts(clean, opts)) + } + + async fn open_draft(&self) -> Result { + let draft_path_str = 
self.draft_path.to_str().unwrap(); + let draft = turso::Builder::new_local(draft_path_str).build().await?; + let opts = DatabaseTapeOpts { + cdc_table: None, + cdc_mode: Some("after".to_string()), + }; + tracing::debug!("initialize draft database connection"); + Ok(DatabaseTape::new_with_opts(draft, opts)) + } + + async fn switch_active(&mut self, active_type: ActiveDatabase, db: DatabaseTape) -> Result<()> { + let mut database = self.database.write().await; + self.meta = Some(self.write_meta(|meta| meta.active_db = active_type).await?); + *database = ActiveDatabaseContainer { + db: Some(db), + active_type, + }; + Ok(()) + } + + async fn bootstrap_db_files(&mut self) -> Result { + assert!( + self.meta.is_none(), + "bootstrap_db_files must be called only when meta is not set" + ); + if self.filesystem.exists_file(&self.draft_path).await? { + self.filesystem.remove_file(&self.draft_path).await?; + } + if self.filesystem.exists_file(&self.synced_path).await? { + self.filesystem.remove_file(&self.synced_path).await?; + } + + let info = self.sync_server.db_info().await?; + let mut synced_file = self.filesystem.create_file(&self.synced_path).await?; + + let start_time = tokio::time::Instant::now(); + let mut written_bytes = 0; + tracing::debug!("start bootstrapping Synced file from remote"); + + let mut bootstrap = self.sync_server.db_export(info.current_generation).await?; + while let Some(chunk) = bootstrap.read_chunk().await? 
{ + self.filesystem.write_file(&mut synced_file, &chunk).await?; + written_bytes += chunk.len(); + } + + let elapsed = tokio::time::Instant::now().duration_since(start_time); + tracing::debug!( + "finish bootstrapping Synced file from remote: written_bytes={}, elapsed={:?}", + written_bytes, + elapsed + ); + + self.filesystem + .copy_file(&self.synced_path, &self.draft_path) + .await?; + tracing::debug!("copied Synced file to Draft"); + + Ok(DatabaseMetadata { + synced_generation: info.current_generation, + synced_frame_no: 0, + synced_change_id: None, + transferred_change_id: None, + active_db: ActiveDatabase::Draft, + }) + } + + async fn write_meta(&self, update: impl Fn(&mut DatabaseMetadata)) -> Result { + let mut meta = self.meta().clone(); + update(&mut meta); + // todo: what happen if we will actually update the metadata on disk but fail and so in memory state will not be updated + meta.write_to(&self.filesystem, &self.meta_path).await?; + Ok(meta) + } + + /// Pull updates from remote to the Synced database + /// This method will update Synced database WAL frames and [DatabaseMetadata::synced_frame_no] metadata field + async fn pull_synced_from_remote(&mut self) -> Result<()> { + tracing::debug!("pull_synced_from_remote"); + let database = self.database.read().await; + assert!( + database.active_type == ActiveDatabase::Draft, + "Draft database must be active as we will modify Clean" + ); + + let (generation, mut frame_no) = { + let meta = self.meta(); + (meta.synced_generation, meta.synced_frame_no) + }; + + // open fresh connection to the Clean database in order to initiate WAL session + let clean = self.open_synced(false).await?; + let clean_conn = clean.connect().await?; + + let mut wal_session = WalSession::new(&clean_conn); + let mut buffer = Vec::with_capacity(FRAME_SIZE); + loop { + tracing::debug!( + "pull clean wal portion: generation={}, frame={}", + generation, + frame_no + 1 + ); + let pull = self.sync_server.wal_pull(generation, frame_no + 
1).await; + + let mut data = match pull { + Ok(data) => data, + Err(Error::PullNeedCheckpoint(status)) + if status.generation == generation && status.max_frame_no == frame_no => + { + tracing::debug!("end of history reached for database: status={:?}", status); + break; + } + Err(e @ Error::PullNeedCheckpoint(..)) => { + // todo(sivukhin): temporary not supported - will implement soon after TRUNCATE checkpoint will be merged to turso-db + return Err(e); + } + Err(e) => return Err(e), + }; + while let Some(mut chunk) = data.read_chunk().await? { + // chunk is arbitrary - aggregate groups of FRAME_SIZE bytes out from the chunks stream + while !chunk.is_empty() { + let to_fill = FRAME_SIZE - buffer.len(); + let prefix = chunk.split_to(to_fill.min(chunk.len())); + buffer.extend_from_slice(&prefix); + assert!( + buffer.capacity() == FRAME_SIZE, + "buffer should not extend its capacity" + ); + if buffer.len() < FRAME_SIZE { + continue; + } + frame_no += 1; + if !wal_session.in_txn() { + wal_session.begin()?; + } + let wal_insert_info = clean_conn.wal_insert_frame(frame_no as u32, &buffer)?; + if wal_insert_info.is_commit { + wal_session.end()?; + // transaction boundary reached - it's safe to commit progress + self.meta = Some(self.write_meta(|m| m.synced_frame_no = frame_no).await?); + } + buffer.clear(); + } + } + } + Ok(()) + } + + async fn push_synced_to_remote(&mut self, change_id: Option) -> Result<()> { + tracing::debug!("push_synced_to_remote"); + match self.do_push_synced_to_remote().await { + Ok(()) => { + self.meta = Some( + self.write_meta(|meta| { + meta.synced_change_id = change_id; + meta.transferred_change_id = None; + }) + .await?, + ); + Ok(()) + } + Err(err @ Error::PushConflict) => { + tracing::info!("push_synced_to_remote: conflict detected, rollback local changes"); + // we encountered conflict - which means that other client pushed something to the WAL before us + // as we were unable to insert any frame to the remote WAL - it's safe to reset our 
state completely + self.meta = Some( + self.write_meta(|meta| meta.transferred_change_id = None) + .await?, + ); + self.reset_synced().await?; + Err(err) + } + Err(err) => { + tracing::info!("err: {}", err); + Err(err) + } + } + } + + async fn do_push_synced_to_remote(&mut self) -> Result<()> { + tracing::debug!("do_push_synced_to_remote"); + let database = self.database.read().await; + assert!(database.active_type == ActiveDatabase::Draft); + + let (generation, frame_no) = { + let meta = self.meta(); + (meta.synced_generation, meta.synced_frame_no) + }; + + let clean = self.open_synced(false).await?; + let clean_conn = clean.connect().await?; + + // todo(sivukhin): push frames in multiple batches + let mut frames = Vec::new(); + let mut frames_cnt = 0; + { + let mut wal_session = WalSession::new(&clean_conn); + wal_session.begin()?; + + let clean_frames = clean_conn.wal_frame_count()? as usize; + + let mut buffer = [0u8; FRAME_SIZE]; + for frame_no in (frame_no + 1)..=clean_frames { + clean_conn.wal_get_frame(frame_no as u32, &mut buffer)?; + frames.extend_from_slice(&buffer); + frames_cnt += 1; + } + } + + if frames_cnt == 0 { + return Ok(()); + } + + self.sync_server + .wal_push( + None, + generation, + frame_no + 1, + frame_no + frames_cnt + 1, + frames, + ) + .await?; + self.meta = Some( + self.write_meta(|meta| meta.synced_frame_no = frame_no + frames_cnt) + .await?, + ); + + Ok(()) + } + + /// Transfers row changes from Draft DB to the Clean DB + async fn transfer_draft_to_synced(&mut self, capture: bool) -> Result> { + tracing::debug!("transfer_draft_to_synced"); + let database = self.database.read().await; + assert!(database.active_type == ActiveDatabase::Draft); + self.synced_is_dirty = true; + + let draft = self.open_draft().await?; + let synced = self.open_synced(capture).await?; + let opts = DatabaseChangesIteratorOpts { + first_change_id: self.meta().synced_change_id.map(|x| x + 1), + mode: DatabaseChangesIteratorMode::Apply, + ..Default::default() + 
}; + let mut last_change_id = self.meta().synced_change_id; + let mut session = synced.start_tape_session().await?; + let mut changes = draft.iterate_changes(opts).await?; + while let Some(operation) = changes.next().await? { + if let DatabaseTapeOperation::RowChange(change) = &operation { + assert!( + last_change_id.is_none() || last_change_id.unwrap() < change.change_id, + "change id must be strictly increasing: last_change_id={:?}, change.change_id={}", + last_change_id, change.change_id + ); + // we give user full control over CDC table - so let's not emit assert here for now + if last_change_id.is_some() && last_change_id.unwrap() + 1 != change.change_id { + tracing::warn!( + "out of order change sequence: {} -> {}", + last_change_id.unwrap(), + change.change_id + ); + } + last_change_id = Some(change.change_id); + } + session.replay(operation).await?; + } + + Ok(last_change_id) + } + + /// [Self::transfer_synced_to_draft] can fail and require cleanup from next calls of another sync methods + /// [Self::cleanup_synced] will check if active DB is Synced and perform necessary cleanup + async fn cleanup_synced(&mut self) -> Result<()> { + tracing::debug!("cleanup_synced"); + + if self.meta().active_db == ActiveDatabase::Synced { + tracing::info!("active_db was set to Synced - finish transfer of Synced DB to the Draft and switch active database"); + self.transfer_synced_to_draft().await?; + self.switch_active(ActiveDatabase::Draft, self.open_draft().await?) 
+ .await?; + } + if let Some(change_id) = self.meta().transferred_change_id { + tracing::info!("some changes was transferred to the Synced DB but wasn't properly pushed to the remote"); + match self.push_synced_to_remote(Some(change_id)).await { + // ignore Ok and Error::PushConflict - because in this case we sucessfully finalized previous operation + Ok(()) | Err(Error::PushConflict) => {} + Err(err) => return Err(err), + } + } + // if we failed in the middle before - let's reset Synced DB if necessary + // if everything works without error - we will properly set is_synced_dirty flag and this function will be no-op + self.reset_synced().await?; + Ok(()) + } + + async fn transfer_synced_to_draft(&mut self) -> Result<()> { + tracing::debug!("transfer_synced_to_draft"); + { + let database = self.database.read().await; + assert!(database.active_type == ActiveDatabase::Synced); + } + + let draft_path_str = self.draft_path.to_str().unwrap_or(""); + let clean_path_str = self.synced_path.to_str().unwrap_or(""); + let draft_wal = PathBuf::from(format!("{draft_path_str}-wal")); + let clean_wal = PathBuf::from(format!("{clean_path_str}-wal")); + let draft_shm = PathBuf::from(format!("{draft_path_str}-shm")); + self.filesystem + .copy_file(&self.synced_path, &self.draft_path) + .await?; + self.filesystem.copy_file(&clean_wal, &draft_wal).await?; + self.filesystem.remove_file(&draft_shm).await?; + + Ok(()) + } + + /// Reset WAL of Synced database which potentially can have some local changes + async fn reset_synced(&mut self) -> Result<()> { + tracing::debug!("reset_synced"); + { + let database = self.database.read().await; + assert!(database.active_type == ActiveDatabase::Draft); + } + + // if we know that Clean DB is not dirty - let's skip this phase completely + if !self.synced_is_dirty { + return Ok(()); + } + + let clean_path_str = self.synced_path.to_str().unwrap_or(""); + let clean_wal_path = PathBuf::from(format!("{clean_path_str}-wal")); + let wal_size = WAL_HEADER + 
FRAME_SIZE * self.meta().synced_frame_no; + tracing::debug!( + "reset Synced DB WAL to the size of {} frames", + self.meta().synced_frame_no + ); + match self.filesystem.open_file(&clean_wal_path).await { + Ok(clean_wal) => { + self.filesystem.truncate_file(&clean_wal, wal_size).await?; + } + Err(Error::FilesystemError(err)) if err.kind() == ErrorKind::NotFound => {} + Err(err) => return Err(err), + } + + self.synced_is_dirty = false; + Ok(()) + } + + fn meta(&self) -> &DatabaseMetadata { + self.meta.as_ref().expect("metadata must be set") + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rand::RngCore; + use tokio::sync::Mutex; + use turso::Value; + + use crate::{ + database_inner::DatabaseInner, + errors::Error, + filesystem::{test::TestFilesystem, Filesystem}, + sync_server::{ + test::{convert_rows, TestSyncServer, TestSyncServerOpts}, + SyncServer, + }, + test_context::{FaultInjectionStrategy, TestContext}, + tests::{deterministic_runtime, seed_u64}, + Result, + }; + + async fn query_rows( + db: &DatabaseInner, + sql: &str, + ) -> Result>> { + let mut rows = db.query(sql, ()).await?; + convert_rows(&mut rows).await + } + + #[test] + pub fn test_sync_single_db_simple() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let opts = TestSyncServerOpts { pull_batch_size: 1 }; + let ctx = Arc::new(TestContext::new(seed_u64())); + let server = TestSyncServer::new(ctx.clone(), &server_path, opts) + .await + .unwrap(); + let fs = TestFilesystem::new(ctx.clone()); + let local_path = dir.path().join("local.db"); + let mut db = DatabaseInner::new(fs, server.clone(), &local_path) + .await + .unwrap(); + + server + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) + .await + .unwrap(); + server + .execute("INSERT INTO t VALUES (1)", ()) + .await + .unwrap(); + + // no table in schema before sync from remote (as DB was initialized when remote was empty) + assert!(matches!( 
+ query_rows(&db, "SELECT * FROM t").await, + Err(Error::TursoError(turso::Error::SqlExecutionFailure(x))) if x.contains("no such table: t") + )); + + // 1 rows synced + db.pull().await.unwrap(); + assert_eq!( + query_rows(&db, "SELECT * FROM t").await.unwrap(), + vec![vec![Value::Integer(1)]] + ); + + server + .execute("INSERT INTO t VALUES (2)", ()) + .await + .unwrap(); + + db.execute("INSERT INTO t VALUES (3)", ()).await.unwrap(); + + // changes are synced from the remote - but remote changes are not propagated locally + db.push().await.unwrap(); + assert_eq!( + query_rows(&db, "SELECT * FROM t").await.unwrap(), + vec![vec![Value::Integer(1)], vec![Value::Integer(3)]] + ); + + let server_db = server.db(); + let server_conn = server_db.connect().unwrap(); + assert_eq!( + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(), + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)], + ] + ); + + db.execute("INSERT INTO t VALUES (4)", ()).await.unwrap(); + db.push().await.unwrap(); + + assert_eq!( + query_rows(&db, "SELECT * FROM t").await.unwrap(), + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(3)], + vec![Value::Integer(4)] + ] + ); + + assert_eq!( + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(), + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)], + vec![Value::Integer(4)], + ] + ); + + db.pull().await.unwrap(); + assert_eq!( + query_rows(&db, "SELECT * FROM t").await.unwrap(), + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)], + vec![Value::Integer(4)] + ] + ); + + assert_eq!( + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(), + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)], + vec![Value::Integer(4)], + ] + ); + }); + } + + #[test] + pub fn 
test_sync_single_db_full_syncs() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let opts = TestSyncServerOpts { pull_batch_size: 1 }; + let ctx = Arc::new(TestContext::new(seed_u64())); + let server = TestSyncServer::new(ctx.clone(), &server_path, opts) + .await + .unwrap(); + let fs = TestFilesystem::new(ctx.clone()); + let local_path = dir.path().join("local.db"); + let mut db = DatabaseInner::new(fs, server.clone(), &local_path) + .await + .unwrap(); + + server + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) + .await + .unwrap(); + server + .execute("INSERT INTO t VALUES (1)", ()) + .await + .unwrap(); + + // no table in schema before sync from remote (as DB was initialized when remote was empty) + assert!(matches!( + query_rows(&db, "SELECT * FROM t").await, + Err(Error::TursoError(turso::Error::SqlExecutionFailure(x))) if x.contains("no such table: t") + )); + + db.sync().await.unwrap(); + assert_eq!( + query_rows(&db, "SELECT * FROM t").await.unwrap(), + vec![vec![Value::Integer(1)]] + ); + + db.execute("INSERT INTO t VALUES (2)", ()).await.unwrap(); + db.sync().await.unwrap(); + assert_eq!( + query_rows(&db, "SELECT * FROM t").await.unwrap(), + vec![vec![Value::Integer(1)], vec![Value::Integer(2)]] + ); + + db.execute("INSERT INTO t VALUES (3)", ()).await.unwrap(); + db.sync().await.unwrap(); + assert_eq!( + query_rows(&db, "SELECT * FROM t").await.unwrap(), + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)] + ] + ); + }); + } + + #[test] + pub fn test_sync_multiple_dbs_conflict() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let opts = TestSyncServerOpts { pull_batch_size: 1 }; + let ctx = Arc::new(TestContext::new(seed_u64())); + let server = TestSyncServer::new(ctx.clone(), &server_path, opts) + .await + .unwrap(); + let mut dbs = 
Vec::new(); + const CLIENTS: usize = 8; + for i in 0..CLIENTS { + let db = DatabaseInner::new( + TestFilesystem::new(ctx.clone()), + server.clone(), + &dir.path().join(format!("local-{i}.db")), + ) + .await + .unwrap(); + dbs.push(Arc::new(Mutex::new(db))); + } + + server + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) + .await + .unwrap(); + + for db in &mut dbs { + let mut db = db.lock().await; + db.pull().await.unwrap(); + } + for (i, db) in dbs.iter().enumerate() { + let db = db.lock().await; + db.execute("INSERT INTO t VALUES (?)", (i as i32,)) + .await + .unwrap(); + } + + let try_sync = || async { + let mut tasks = Vec::new(); + for db in &dbs { + let db = db.clone(); + tasks.push(async move { + let mut db = db.lock().await; + db.push().await + }); + } + futures::future::join_all(tasks).await + }; + for attempt in 0..CLIENTS { + let results = try_sync().await; + tracing::info!("attempt #{}: {:?}", attempt, results); + assert!(results.iter().filter(|x| x.is_ok()).count() >= attempt); + } + }); + } + + #[test] + pub fn test_sync_multiple_clients_no_conflicts_synchronized() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let opts = TestSyncServerOpts { pull_batch_size: 1 }; + let ctx = Arc::new(TestContext::new(seed_u64())); + let server = TestSyncServer::new(ctx.clone(), &server_path, opts) + .await + .unwrap(); + + server + .execute("CREATE TABLE t(k INTEGER PRIMARY KEY, v)", ()) + .await + .unwrap(); + + let sync_lock = Arc::new(tokio::sync::Mutex::new(())); + let mut clients = Vec::new(); + const CLIENTS: usize = 10; + let mut expected_rows = Vec::new(); + for i in 0..CLIENTS { + let mut queries = Vec::new(); + let cnt = ctx.rng().await.next_u32() % CLIENTS as u32 + 1; + for q in 0..cnt { + let key = i * CLIENTS + q as usize; + let length = ctx.rng().await.next_u32() % 4096; + queries.push(format!( + "INSERT INTO t VALUES ({key}, randomblob({length}))", + )); + 
expected_rows.push(vec![ + Value::Integer(key as i64), + Value::Integer(length as i64), + ]); + } + clients.push(tokio::spawn({ + let dir = dir.path().to_path_buf().clone(); + let ctx = ctx.clone(); + let server = server.clone(); + let sync_lock = sync_lock.clone(); + async move { + let local_path = dir.join(format!("local-{i}.db")); + let fs = TestFilesystem::new(ctx.clone()); + let mut db = DatabaseInner::new(fs, server.clone(), &local_path) + .await + .unwrap(); + + db.pull().await.unwrap(); + for query in queries { + db.execute(&query, ()).await.unwrap(); + } + let guard = sync_lock.lock().await; + db.push().await.unwrap(); + drop(guard); + } + })); + } + for client in clients { + client.await.unwrap(); + } + let db = server.db(); + let conn = db.connect().unwrap(); + let mut result = conn.query("SELECT k, length(v) FROM t", ()).await.unwrap(); + let rows = convert_rows(&mut result).await.unwrap(); + assert_eq!(rows, expected_rows); + }); + } + + #[test] + pub fn test_sync_single_db_sync_from_remote_nothing_single_failure() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let server_path = dir.path().join("server.db"); + let opts = TestSyncServerOpts { pull_batch_size: 1 }; + let ctx = Arc::new(TestContext::new(seed_u64())); + let server = TestSyncServer::new(ctx.clone(), &server_path, opts) + .await + .unwrap(); + + server.execute("CREATE TABLE t(x)", ()).await.unwrap(); + server + .execute("INSERT INTO t VALUES (1), (2), (3)", ()) + .await + .unwrap(); + + let mut session = ctx.fault_session(); + let mut it = 0; + while let Some(strategy) = session.next().await { + let local_path = dir.path().join(format!("local-{it}.db")); + it += 1; + + let fs = TestFilesystem::new(ctx.clone()); + let mut db = DatabaseInner::new(fs, server.clone(), &local_path) + .await + .unwrap(); + + let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. 
}); + + ctx.switch_mode(strategy).await; + let result = db.pull().await; + ctx.switch_mode(FaultInjectionStrategy::Disabled).await; + + if !has_fault { + result.unwrap(); + } else { + let err = result.err().unwrap(); + tracing::info!("error after fault injection: {}", err); + } + + let rows = query_rows(&db, "SELECT * FROM t").await.unwrap(); + assert_eq!( + rows, + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)], + ] + ); + + db.pull().await.unwrap(); + + let rows = query_rows(&db, "SELECT * FROM t").await.unwrap(); + assert_eq!( + rows, + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)], + ] + ); + } + }); + } + + #[test] + pub fn test_sync_single_db_sync_from_remote_single_failure() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let opts = TestSyncServerOpts { pull_batch_size: 1 }; + let ctx = Arc::new(TestContext::new(seed_u64())); + + let mut session = ctx.fault_session(); + let mut it = 0; + while let Some(strategy) = session.next().await { + let server_path = dir.path().join(format!("server-{it}.db")); + let server = TestSyncServer::new(ctx.clone(), &server_path, opts.clone()) + .await + .unwrap(); + + server.execute("CREATE TABLE t(x)", ()).await.unwrap(); + + let local_path = dir.path().join(format!("local-{it}.db")); + it += 1; + + let fs = TestFilesystem::new(ctx.clone()); + let mut db = DatabaseInner::new(fs, server.clone(), &local_path) + .await + .unwrap(); + + server + .execute("INSERT INTO t VALUES (1), (2), (3)", ()) + .await + .unwrap(); + + let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. 
}); + + ctx.switch_mode(strategy).await; + let result = db.pull().await; + ctx.switch_mode(FaultInjectionStrategy::Disabled).await; + + if !has_fault { + result.unwrap(); + } else { + let err = result.err().unwrap(); + tracing::info!("error after fault injection: {}", err); + } + + let rows = query_rows(&db, "SELECT * FROM t").await.unwrap(); + assert!(rows.len() <= 3); + + db.pull().await.unwrap(); + + let rows = query_rows(&db, "SELECT * FROM t").await.unwrap(); + assert_eq!( + rows, + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)], + ] + ); + } + }); + } + + #[test] + pub fn test_sync_single_db_sync_to_remote_single_failure() { + deterministic_runtime(async || { + let dir = tempfile::TempDir::new().unwrap(); + let opts = TestSyncServerOpts { pull_batch_size: 1 }; + let ctx = Arc::new(TestContext::new(seed_u64())); + + let mut session = ctx.fault_session(); + let mut it = 0; + while let Some(strategy) = session.next().await { + let server_path = dir.path().join(format!("server-{it}.db")); + let server = TestSyncServer::new(ctx.clone(), &server_path, opts.clone()) + .await + .unwrap(); + + server + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY)", ()) + .await + .unwrap(); + + server + .execute("INSERT INTO t VALUES (1)", ()) + .await + .unwrap(); + + let local_path = dir.path().join(format!("local-{it}.db")); + it += 1; + + let fs = TestFilesystem::new(ctx.clone()); + let mut db = DatabaseInner::new(fs, server.clone(), &local_path) + .await + .unwrap(); + + db.execute("INSERT INTO t VALUES (2), (3)", ()) + .await + .unwrap(); + + let has_fault = matches!(strategy, FaultInjectionStrategy::Enabled { .. 
}); + + ctx.switch_mode(strategy).await; + let result = db.push().await; + ctx.switch_mode(FaultInjectionStrategy::Disabled).await; + + if !has_fault { + result.unwrap(); + } else { + let err = result.err().unwrap(); + tracing::info!("error after fault injection: {}", err); + } + + let server_db = server.db(); + let server_conn = server_db.connect().unwrap(); + let rows = + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(); + assert!(rows.len() <= 3); + + db.push().await.unwrap(); + + let rows = + convert_rows(&mut server_conn.query("SELECT * FROM t", ()).await.unwrap()) + .await + .unwrap(); + assert_eq!( + rows, + vec![ + vec![Value::Integer(1)], + vec![Value::Integer(2)], + vec![Value::Integer(3)], + ] + ); + } + }); + } +} diff --git a/packages/turso-sync/src/database_tape.rs b/packages/turso-sync/src/database_tape.rs index 4610930dd..4b108301c 100644 --- a/packages/turso-sync/src/database_tape.rs +++ b/packages/turso-sync/src/database_tape.rs @@ -309,8 +309,12 @@ fn parse_bin_record(bin_record: Vec) -> Result> { #[cfg(test)] mod tests { use tempfile::NamedTempFile; + use turso::Value; - use crate::database_tape::DatabaseTape; + use crate::{ + database_tape::{DatabaseChangesIteratorOpts, DatabaseTape}, + types::DatabaseTapeOperation, + }; async fn fetch_rows(conn: &turso::Connection, query: &str) -> Vec> { let mut rows = vec![]; @@ -326,7 +330,7 @@ mod tests { } #[tokio::test] - async fn test_database_cdc() { + async fn test_database_cdc_single_iteration() { let temp_file1 = NamedTempFile::new().unwrap(); let db_path1 = temp_file1.path().to_str().unwrap(); @@ -475,4 +479,62 @@ mod tests { ]] ); } + + #[tokio::test] + async fn test_database_cdc_multiple_iterations() { + let temp_file1 = NamedTempFile::new().unwrap(); + let db_path1 = temp_file1.path().to_str().unwrap(); + + let temp_file2 = NamedTempFile::new().unwrap(); + let db_path2 = temp_file2.path().to_str().unwrap(); + + let db1 = 
turso::Builder::new_local(db_path1).build().await.unwrap(); + let db1 = DatabaseTape::new(db1); + let conn1 = db1.connect().await.unwrap(); + + let db2 = turso::Builder::new_local(db_path2).build().await.unwrap(); + let db2 = DatabaseTape::new(db2); + let conn2 = db2.connect().await.unwrap(); + + conn1 + .execute("CREATE TABLE a(x INTEGER PRIMARY KEY, y);", ()) + .await + .unwrap(); + conn2 + .execute("CREATE TABLE a(x INTEGER PRIMARY KEY, y);", ()) + .await + .unwrap(); + + let mut next_change_id = None; + let mut expected = Vec::new(); + for i in 0..10 { + conn1 + .execute("INSERT INTO a VALUES (?, 'hello')", (i,)) + .await + .unwrap(); + expected.push(vec![ + Value::Integer(i as i64), + Value::Text("hello".to_string()), + ]); + + let mut iterator = db1 + .iterate_changes(DatabaseChangesIteratorOpts { + first_change_id: next_change_id, + ..Default::default() + }) + .await + .unwrap(); + { + let mut replay = db2.start_tape_session().await.unwrap(); + while let Some(change) = iterator.next().await.unwrap() { + if let DatabaseTapeOperation::RowChange(change) = &change { + next_change_id = Some(change.change_id + 1); + } + replay.replay(change).await.unwrap(); + } + } + let conn2 = db2.connect().await.unwrap(); + assert_eq!(fetch_rows(&conn2, "SELECT * FROM a").await, expected); + } + } } diff --git a/packages/turso-sync/src/errors.rs b/packages/turso-sync/src/errors.rs index dc9191e46..21837c20f 100644 --- a/packages/turso-sync/src/errors.rs +++ b/packages/turso-sync/src/errors.rs @@ -1,9 +1,37 @@ +use crate::sync_server::DbSyncStatus; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("database error: {0}")] TursoError(turso::Error), #[error("database tape error: {0}")] DatabaseTapeError(String), + #[error("invalid URI: {0}")] + Uri(http::uri::InvalidUri), + #[error("invalid HTTP request: {0}")] + Http(http::Error), + #[error("HTTP request error: {0}")] + HyperRequest(hyper_util::client::legacy::Error), + #[error("HTTP response error: {0}")] + 
HyperResponse(hyper::Error), + #[error("deserialization error: {0}")] + JsonDecode(serde_json::Error), + #[error("unexpected sync server error: code={0}, info={1}")] + SyncServerError(http::StatusCode, String), + #[error("unexpected sync server status: {0:?}")] + SyncServerUnexpectedStatus(DbSyncStatus), + #[error("unexpected filesystem error: {0}")] + FilesystemError(std::io::Error), + #[error("local metadata error: {0}")] + MetadataError(String), + #[error("database sync error: {0}")] + DatabaseSyncError(String), + #[error("sync server pull error: checkpoint required: `{0:?}`")] + PullNeedCheckpoint(DbSyncStatus), + #[error("sync server push error: wal conflict detected")] + PushConflict, + #[error("sync server push error: inconsitent state on remote: `{0:?}`")] + PushInconsistent(DbSyncStatus), } impl From for Error { @@ -17,3 +45,15 @@ impl From for Error { Self::TursoError(value.into()) } } + +impl From for Error { + fn from(value: std::io::Error) -> Self { + Self::FilesystemError(value) + } +} + +impl From for Error { + fn from(value: serde_json::Error) -> Self { + Self::JsonDecode(value) + } +} diff --git a/packages/turso-sync/src/filesystem/mod.rs b/packages/turso-sync/src/filesystem/mod.rs new file mode 100644 index 000000000..027d18d0a --- /dev/null +++ b/packages/turso-sync/src/filesystem/mod.rs @@ -0,0 +1,42 @@ +#[cfg(test)] +pub mod test; +pub mod tokio; + +use crate::Result; +use std::path::Path; + +pub trait Filesystem { + type File; + fn exists_file(&self, path: &Path) -> impl std::future::Future> + Send; + fn remove_file(&self, path: &Path) -> impl std::future::Future> + Send; + fn create_file( + &self, + path: &Path, + ) -> impl std::future::Future> + Send; + fn open_file( + &self, + path: &Path, + ) -> impl std::future::Future> + Send; + fn copy_file( + &self, + src: &Path, + dst: &Path, + ) -> impl std::future::Future> + Send; + fn rename_file( + &self, + src: &Path, + dst: &Path, + ) -> impl std::future::Future> + Send; + fn truncate_file( + 
&self, + file: &Self::File, + size: usize, + ) -> impl std::future::Future> + Send; + fn write_file( + &self, + file: &mut Self::File, + buf: &[u8], + ) -> impl std::future::Future> + Send; + fn sync_file(&self, file: &Self::File) -> impl std::future::Future> + Send; + fn read_file(&self, path: &Path) -> impl std::future::Future>> + Send; +} diff --git a/packages/turso-sync/src/filesystem/test.rs b/packages/turso-sync/src/filesystem/test.rs new file mode 100644 index 000000000..207977e0e --- /dev/null +++ b/packages/turso-sync/src/filesystem/test.rs @@ -0,0 +1,97 @@ +use std::{io::Write, sync::Arc}; + +use crate::{filesystem::Filesystem, test_context::TestContext, Result}; + +pub struct TestFilesystem { + ctx: Arc, +} + +impl TestFilesystem { + pub fn new(ctx: Arc) -> Self { + Self { ctx } + } +} + +impl Filesystem for TestFilesystem { + type File = std::fs::File; + + async fn exists_file(&self, path: &std::path::Path) -> Result { + self.ctx.faulty_call("exists_file_start").await?; + let result = std::fs::exists(path)?; + self.ctx.faulty_call("exists_file_end").await?; + Ok(result) + } + + async fn remove_file(&self, path: &std::path::Path) -> Result<()> { + self.ctx.faulty_call("remove_file_start").await?; + match std::fs::remove_file(path) { + Ok(()) => Result::Ok(()), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Result::Ok(()), + Err(e) => Err(e.into()), + }?; + self.ctx.faulty_call("remove_file_end").await?; + Ok(()) + } + + async fn create_file(&self, path: &std::path::Path) -> Result { + self.ctx.faulty_call("create_file_start").await?; + let result = std::fs::OpenOptions::new() + .write(true) + .create_new(true) + .open(path)?; + self.ctx.faulty_call("create_file_end").await?; + Ok(result) + } + + async fn open_file(&self, path: &std::path::Path) -> Result { + self.ctx.faulty_call("open_file_start").await?; + let result = std::fs::OpenOptions::new() + .read(true) + .write(true) + .open(path)?; + self.ctx.faulty_call("open_file_end").await?; + 
Ok(result) + } + + async fn copy_file(&self, src: &std::path::Path, dst: &std::path::Path) -> Result<()> { + self.ctx.faulty_call("copy_file_start").await?; + std::fs::copy(src, dst)?; + self.ctx.faulty_call("copy_file_end").await?; + Ok(()) + } + + async fn rename_file(&self, src: &std::path::Path, dst: &std::path::Path) -> Result<()> { + self.ctx.faulty_call("rename_file_start").await?; + std::fs::rename(src, dst)?; + self.ctx.faulty_call("rename_file_end").await?; + Ok(()) + } + + async fn truncate_file(&self, file: &Self::File, size: usize) -> Result<()> { + self.ctx.faulty_call("truncate_file_start").await?; + file.set_len(size as u64)?; + self.ctx.faulty_call("truncate_file_end").await?; + Ok(()) + } + + async fn write_file(&self, file: &mut Self::File, buf: &[u8]) -> Result<()> { + self.ctx.faulty_call("write_file_start").await?; + file.write_all(buf)?; + self.ctx.faulty_call("write_file_end").await?; + Ok(()) + } + + async fn sync_file(&self, file: &Self::File) -> Result<()> { + self.ctx.faulty_call("sync_file_start").await?; + file.sync_all()?; + self.ctx.faulty_call("sync_file_end").await?; + Ok(()) + } + + async fn read_file(&self, path: &std::path::Path) -> Result> { + self.ctx.faulty_call("read_file_start").await?; + let data = std::fs::read(path)?; + self.ctx.faulty_call("read_file_end").await?; + Ok(data) + } +} diff --git a/packages/turso-sync/src/filesystem/tokio.rs b/packages/turso-sync/src/filesystem/tokio.rs new file mode 100644 index 000000000..b729d21c1 --- /dev/null +++ b/packages/turso-sync/src/filesystem/tokio.rs @@ -0,0 +1,78 @@ +use std::path::Path; + +use tokio::io::AsyncWriteExt; + +use crate::{filesystem::Filesystem, Result}; + +pub struct TokioFilesystem(); + +impl Filesystem for TokioFilesystem { + type File = tokio::fs::File; + + async fn exists_file(&self, path: &Path) -> Result { + tracing::debug!("check file exists at {:?}", path); + Ok(tokio::fs::try_exists(&path).await?) 
+ } + + async fn remove_file(&self, path: &Path) -> Result<()> { + tracing::debug!("remove file at {:?}", path); + match tokio::fs::remove_file(path).await { + Ok(()) => Ok(()), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()), + Err(e) => Err(e.into()), + } + } + + async fn create_file(&self, path: &Path) -> Result { + tracing::debug!("create file at {:?}", path); + Ok(tokio::fs::File::create_new(path) + .await + .inspect_err(|e| tracing::error!("failed to create file at {:?}: {}", path, e))?) + } + + async fn open_file(&self, path: &Path) -> Result { + tracing::debug!("open file at {:?}", path); + Ok(tokio::fs::OpenOptions::new() + .write(true) + .read(true) + .open(path) + .await?) + } + + async fn copy_file(&self, src: &Path, dst: &Path) -> Result<()> { + tracing::debug!("copy file from {:?} to {:?}", src, dst); + tokio::fs::copy(&src, &dst).await?; + Ok(()) + } + + async fn rename_file(&self, src: &Path, dst: &Path) -> Result<()> { + tracing::debug!("rename file from {:?} to {:?}", src, dst); + tokio::fs::rename(&src, &dst) + .await + .inspect_err(|e| tracing::error!("failed to rename {:?} to {:?}: {}", src, dst, e))?; + Ok(()) + } + + async fn truncate_file(&self, file: &Self::File, size: usize) -> Result<()> { + tracing::debug!("truncate file to size {}", size); + file.set_len(size as u64).await?; + Ok(()) + } + + async fn write_file(&self, file: &mut Self::File, buf: &[u8]) -> Result<()> { + tracing::debug!("write buffer of size {} to file", buf.len()); + file.write_all(buf).await?; + Ok(()) + } + + async fn sync_file(&self, file: &Self::File) -> Result<()> { + tracing::debug!("sync file"); + file.sync_all().await?; + Ok(()) + } + + async fn read_file(&self, path: &Path) -> Result> { + tracing::debug!("read file {:?}", path); + Ok(tokio::fs::read(path).await?) 
+ } +} diff --git a/packages/turso-sync/src/lib.rs b/packages/turso-sync/src/lib.rs index b280e1d4d..f5697c51e 100644 --- a/packages/turso-sync/src/lib.rs +++ b/packages/turso-sync/src/lib.rs @@ -1,9 +1,18 @@ +pub mod database; pub mod database_tape; pub mod errors; pub mod types; pub type Result = std::result::Result; +mod database_inner; +mod filesystem; +mod metadata; +mod sync_server; +#[cfg(test)] +mod test_context; +mod wal_session; + #[cfg(test)] mod tests { use tracing_subscriber::EnvFilter; @@ -12,6 +21,34 @@ mod tests { fn init() { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) + // .with_ansi(false) .init(); } + + pub fn seed_u64() -> u64 { + seed().parse().unwrap_or(0) + } + + pub fn seed() -> String { + std::env::var("SEED").unwrap_or("0".to_string()) + } + + pub fn deterministic_runtime_from_seed>( + seed: &[u8], + f: impl Fn() -> F, + ) { + let seed = tokio::runtime::RngSeed::from_bytes(seed); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_time() + .start_paused(true) + .rng_seed(seed) + .build_local(Default::default()) + .unwrap(); + runtime.block_on(f()); + } + + pub fn deterministic_runtime>(f: impl Fn() -> F) { + let seed = seed(); + deterministic_runtime_from_seed(seed.as_bytes(), f); + } } diff --git a/packages/turso-sync/src/metadata.rs b/packages/turso-sync/src/metadata.rs new file mode 100644 index 000000000..78d74fa33 --- /dev/null +++ b/packages/turso-sync/src/metadata.rs @@ -0,0 +1,112 @@ +use std::path::Path; + +use serde::{Deserialize, Serialize}; + +use crate::{errors::Error, filesystem::Filesystem, Result}; + +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq)] +pub enum ActiveDatabase { + /// Draft database is the only one from the pair which can accept writes + /// It holds all local changes + Draft, + /// Synced database most of the time holds DB state from remote + /// We can temporary apply changes from Draft DB to it - but they will be reseted almost immediately + 
Synced, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct DatabaseMetadata { + /// Latest generation from remote which was pulled locally to the Synced DB + pub synced_generation: usize, + /// Latest frame number from remote which was pulled locally to the Synced DB + pub synced_frame_no: usize, + /// Latest change_id from CDC table in Draft DB which was successfully pushed to the remote through Synced DB + pub synced_change_id: Option, + /// Optional field which will store change_id from CDC table in Draft DB which was successfully transferred to the SyncedDB + /// but not durably pushed to the remote yet + /// + /// This can happen if a WAL push aborts in the middle due to network partition, application crash, etc + pub transferred_change_id: Option, + /// Current active database + pub active_db: ActiveDatabase, +} + +impl DatabaseMetadata { + pub async fn read_from(fs: &impl Filesystem, path: &Path) -> Result<Option<Self>> { + tracing::debug!("try read metadata from: {:?}", path); + if !fs.exists_file(path).await? 
{ + tracing::debug!("no metadata found at {:?}", path); + return Ok(None); + } + let contents = fs.read_file(path).await?; + let meta = serde_json::from_slice::(&contents[..])?; + tracing::debug!("read metadata from {:?}: {:?}", path, meta); + Ok(Some(meta)) + } + pub async fn write_to(&self, fs: &impl Filesystem, path: &Path) -> Result<()> { + tracing::debug!("write metadata to {:?}: {:?}", path, self); + let directory = path.parent().ok_or_else(|| { + Error::MetadataError(format!( + "unable to get parent of the provided path: {path:?}", + )) + })?; + let filename = path + .file_name() + .and_then(|x| x.to_str()) + .ok_or_else(|| Error::MetadataError(format!("unable to get filename: {path:?}")))?; + + let timestamp = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH); + let timestamp = timestamp.map_err(|e| { + Error::MetadataError(format!("failed to get current time for temp file: {e}")) + })?; + let temp_name = format!("{}.tmp.{}", filename, timestamp.as_nanos()); + let temp_path = directory.join(temp_name); + + let data = serde_json::to_string(self)?; + + let mut temp_file = fs.create_file(&temp_path).await?; + let mut result = fs.write_file(&mut temp_file, data.as_bytes()).await; + if result.is_ok() { + result = fs.sync_file(&temp_file).await; + } + drop(temp_file); + if result.is_ok() { + result = fs.rename_file(&temp_path, path).await; + } + if result.is_err() { + let _ = fs.remove_file(&temp_path).await.inspect_err(|e| { + tracing::warn!("failed to remove temp file at {:?}: {}", temp_path, e) + }); + } + result + } +} + +#[cfg(test)] +mod tests { + use crate::{ + filesystem::tokio::TokioFilesystem, + metadata::{ActiveDatabase, DatabaseMetadata}, + }; + + #[tokio::test] + pub async fn metadata_simple_test() { + let dir = tempfile::TempDir::new().unwrap(); + let path = dir.path().join("db-info"); + let meta = DatabaseMetadata { + synced_generation: 1, + synced_frame_no: 2, + synced_change_id: Some(3), + transferred_change_id: Some(4), + 
active_db: ActiveDatabase::Draft, + }; + let fs = TokioFilesystem(); + meta.write_to(&fs, &path).await.unwrap(); + + let read = DatabaseMetadata::read_from(&fs, &path) + .await + .unwrap() + .unwrap(); + assert_eq!(meta, read); + } +} diff --git a/packages/turso-sync/src/sync_server/.gitignore b/packages/turso-sync/src/sync_server/.gitignore new file mode 100644 index 000000000..e199dc294 --- /dev/null +++ b/packages/turso-sync/src/sync_server/.gitignore @@ -0,0 +1 @@ +!empty_wal_mode.db diff --git a/packages/turso-sync/src/sync_server/empty_wal_mode.db b/packages/turso-sync/src/sync_server/empty_wal_mode.db new file mode 100644 index 000000000..0a06b0094 Binary files /dev/null and b/packages/turso-sync/src/sync_server/empty_wal_mode.db differ diff --git a/packages/turso-sync/src/sync_server/mod.rs b/packages/turso-sync/src/sync_server/mod.rs new file mode 100644 index 000000000..f9e9e8db1 --- /dev/null +++ b/packages/turso-sync/src/sync_server/mod.rs @@ -0,0 +1,46 @@ +use crate::Result; + +#[cfg(test)] +pub mod test; +pub mod turso; + +#[derive(Debug, serde::Deserialize)] +pub struct DbSyncInfo { + pub current_generation: usize, +} + +#[derive(Debug, serde::Deserialize)] +pub struct DbSyncStatus { + pub baton: Option, + pub status: String, + pub generation: usize, + pub max_frame_no: usize, +} + +pub trait Stream { + fn read_chunk( + &mut self, + ) -> impl std::future::Future>> + Send; +} + +pub trait SyncServer { + type Stream: Stream; + fn db_info(&self) -> impl std::future::Future> + Send; + fn db_export( + &self, + generation_id: usize, + ) -> impl std::future::Future> + Send; + fn wal_pull( + &self, + generation_id: usize, + start_frame: usize, + ) -> impl std::future::Future> + Send; + fn wal_push( + &self, + baton: Option, + generation_id: usize, + start_frame: usize, + end_frame: usize, + frames: Vec, + ) -> impl std::future::Future> + Send; +} diff --git a/packages/turso-sync/src/sync_server/test.rs b/packages/turso-sync/src/sync_server/test.rs new file 
mode 100644 index 000000000..32b9667b4 --- /dev/null +++ b/packages/turso-sync/src/sync_server/test.rs @@ -0,0 +1,285 @@ +use std::{collections::HashMap, path::Path, sync::Arc}; + +use tokio::sync::Mutex; +use turso::{IntoParams, Value}; + +use crate::{ + errors::Error, + sync_server::{DbSyncInfo, DbSyncStatus, Stream, SyncServer}, + test_context::TestContext, + Result, +}; + +struct Generation { + snapshot: Vec, + frames: Vec>, +} + +#[derive(Clone)] +struct SyncSession { + baton: String, + conn: turso::Connection, + in_txn: bool, +} + +struct TestSyncServerState { + generation: usize, + generations: HashMap, + sessions: HashMap, +} + +#[derive(Debug, Clone)] +pub struct TestSyncServerOpts { + pub pull_batch_size: usize, +} + +#[derive(Clone)] +pub struct TestSyncServer { + ctx: Arc, + db: turso::Database, + opts: Arc, + state: Arc>, +} + +pub struct TestStream { + ctx: Arc, + data: Vec, + position: usize, +} + +impl TestStream { + pub fn new(ctx: Arc, data: Vec) -> Self { + Self { + ctx, + data, + position: 0, + } + } +} + +impl Stream for TestStream { + async fn read_chunk(&mut self) -> Result> { + self.ctx + .faulty_call(if self.position == 0 { + "read_chunk_first" + } else { + "read_chunk_next" + }) + .await?; + let size = (self.data.len() - self.position).min(FRAME_SIZE); + if size == 0 { + Ok(None) + } else { + let chunk = &self.data[self.position..self.position + size]; + self.position += size; + Ok(Some(hyper::body::Bytes::copy_from_slice(chunk))) + } + } +} + +const PAGE_SIZE: usize = 4096; +const FRAME_SIZE: usize = 24 + PAGE_SIZE; + +impl SyncServer for TestSyncServer { + type Stream = TestStream; + async fn db_info(&self) -> Result { + self.ctx.faulty_call("db_info").await?; + + let state = self.state.lock().await; + Ok(DbSyncInfo { + current_generation: state.generation, + }) + } + + async fn db_export(&self, generation_id: usize) -> Result { + self.ctx.faulty_call("db_export").await?; + + let state = self.state.lock().await; + let Some(generation) = 
state.generations.get(&generation_id) else { + return Err(Error::DatabaseSyncError("generation not found".to_string())); + }; + Ok(TestStream::new( + self.ctx.clone(), + generation.snapshot.clone(), + )) + } + + async fn wal_pull(&self, generation_id: usize, start_frame: usize) -> Result { + tracing::debug!("wal_pull: {}/{}", generation_id, start_frame); + self.ctx.faulty_call("wal_pull").await?; + + let state = self.state.lock().await; + let Some(generation) = state.generations.get(&generation_id) else { + return Err(Error::DatabaseSyncError("generation not found".to_string())); + }; + let mut data = Vec::new(); + for frame_no in start_frame..start_frame + self.opts.pull_batch_size { + let frame_idx = frame_no - 1; + let Some(frame) = generation.frames.get(frame_idx) else { + break; + }; + data.extend_from_slice(frame); + } + if data.is_empty() { + let last_generation = state.generations.get(&state.generation).unwrap(); + return Err(Error::PullNeedCheckpoint(DbSyncStatus { + baton: None, + status: "checkpoint_needed".to_string(), + generation: state.generation, + max_frame_no: last_generation.frames.len(), + })); + } + Ok(TestStream::new(self.ctx.clone(), data)) + } + + async fn wal_push( + &self, + mut baton: Option, + generation_id: usize, + start_frame: usize, + end_frame: usize, + frames: Vec, + ) -> Result { + tracing::debug!( + "wal_push: {}/{}/{}/{:?}", + generation_id, + start_frame, + end_frame, + baton + ); + self.ctx.faulty_call("wal_push").await?; + + let mut session = { + let mut state = self.state.lock().await; + if state.generation != generation_id { + return Err(Error::DatabaseSyncError( + "generation id mismatch".to_string(), + )); + } + let baton_str = baton.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let session = match state.sessions.get(&baton_str) { + Some(session) => session.clone(), + None => { + let session = SyncSession { + baton: baton_str.clone(), + conn: self.db.connect()?, + in_txn: false, + }; + 
state.sessions.insert(baton_str.clone(), session.clone()); + session + } + }; + baton = Some(baton_str.clone()); + session + }; + + let mut offset = 0; + for frame_no in start_frame..end_frame { + if offset + FRAME_SIZE > frames.len() { + return Err(Error::DatabaseSyncError( + "unexpected length of frames data".to_string(), + )); + } + if !session.in_txn { + session.conn.wal_insert_begin()?; + session.in_txn = true; + } + let frame = &frames[offset..offset + FRAME_SIZE]; + match session.conn.wal_insert_frame(frame_no as u32, frame) { + Ok(info) => { + if info.is_commit { + if session.in_txn { + session.conn.wal_insert_end()?; + session.in_txn = false; + } + self.sync_frames_from_conn(&session.conn).await?; + } + } + Err(turso::Error::WalOperationError(err)) if err.contains("Conflict") => { + session.conn.wal_insert_end()?; + return Err(Error::PushConflict); + } + Err(err) => { + session.conn.wal_insert_end()?; + return Err(err.into()); + } + } + offset += FRAME_SIZE; + } + let mut state = self.state.lock().await; + state + .sessions + .insert(baton.clone().unwrap(), session.clone()); + Ok(DbSyncStatus { + baton: Some(session.baton.clone()), + status: "ok".into(), + generation: state.generation, + max_frame_no: session.conn.wal_frame_count()? as usize, + }) + } +} + +// empty DB with single 4096-byte page and WAL mode (PRAGMA journal_mode=WAL) +// see test test_empty_wal_mode_db_content which validates asset content +pub const EMPTY_WAL_MODE_DB: &[u8] = include_bytes!("empty_wal_mode.db"); + +pub async fn convert_rows(rows: &mut turso::Rows) -> Result>> { + let mut rows_values = vec![]; + while let Some(row) = rows.next().await? 
{ + let mut row_values = vec![]; + for i in 0..row.column_count() { + row_values.push(row.get_value(i)?); + } + rows_values.push(row_values); + } + Ok(rows_values) +} + +impl TestSyncServer { + pub async fn new(ctx: Arc, path: &Path, opts: TestSyncServerOpts) -> Result { + let mut generations = HashMap::new(); + generations.insert( + 1, + Generation { + snapshot: EMPTY_WAL_MODE_DB.to_vec(), + frames: Vec::new(), + }, + ); + Ok(Self { + ctx, + db: turso::Builder::new_local(path.to_str().unwrap()) + .build() + .await?, + opts: Arc::new(opts), + state: Arc::new(Mutex::new(TestSyncServerState { + generation: 1, + generations, + sessions: HashMap::new(), + })), + }) + } + pub fn db(&self) -> turso::Database { + self.db.clone() + } + pub async fn execute(&self, sql: &str, params: impl IntoParams) -> Result<()> { + let conn = self.db.connect()?; + conn.execute(sql, params).await?; + self.sync_frames_from_conn(&conn).await?; + Ok(()) + } + async fn sync_frames_from_conn(&self, conn: &turso::Connection) -> Result<()> { + let mut state = self.state.lock().await; + let generation = state.generation; + let generation = state.generations.get_mut(&generation).unwrap(); + let last_frame = generation.frames.len() + 1; + let mut frame = [0u8; FRAME_SIZE]; + let wal_frame_count = conn.wal_frame_count()?; + tracing::debug!("conn frames count: {}", wal_frame_count); + for frame_no in last_frame..=wal_frame_count as usize { + conn.wal_get_frame(frame_no as u32, &mut frame)?; + tracing::debug!("push local frame {}", frame_no); + generation.frames.push(frame.to_vec()); + } + Ok(()) + } +} diff --git a/packages/turso-sync/src/sync_server/turso.rs b/packages/turso-sync/src/sync_server/turso.rs new file mode 100644 index 000000000..963103f45 --- /dev/null +++ b/packages/turso-sync/src/sync_server/turso.rs @@ -0,0 +1,202 @@ +use std::io::Read; + +use http::request; +use http_body_util::BodyExt; +use hyper::body::{Buf, Bytes}; + +use crate::{ + errors::Error, + sync_server::{DbSyncInfo, 
DbSyncStatus, Stream, SyncServer}, + Result, +}; + +pub type Client = hyper_util::client::legacy::Client< + hyper_rustls::HttpsConnector, + http_body_util::Full, +>; + +const DEFAULT_PULL_BATCH_SIZE: usize = 100; + +pub struct TursoSyncServerOpts { + pub sync_url: String, + pub auth_token: Option, + pub encryption_key: Option, + pub pull_batch_size: Option, +} + +pub struct TursoSyncServer { + client: Client, + auth_token_header: Option, + opts: TursoSyncServerOpts, +} + +fn sync_server_error(status: http::StatusCode, body: impl Buf) -> Error { + let mut body_str = String::new(); + if let Err(e) = body.reader().read_to_string(&mut body_str) { + Error::SyncServerError(status, format!("unable to read response body: {e}")) + } else { + Error::SyncServerError(status, body_str) + } +} + +async fn aggregate_body(body: hyper::body::Incoming) -> Result { + let chunks = body.collect().await; + let chunks = chunks.map_err(Error::HyperResponse)?; + Ok(chunks.aggregate()) +} + +pub struct HyperStream { + body: hyper::body::Incoming, +} + +impl Stream for HyperStream { + async fn read_chunk(&mut self) -> Result> { + let Some(frame) = self.body.frame().await else { + return Ok(None); + }; + let frame = frame.map_err(Error::HyperResponse)?; + let frame = frame + .into_data() + .map_err(|_| Error::DatabaseSyncError("failed to read export chunk".to_string()))?; + Ok(Some(frame)) + } +} + +impl TursoSyncServer { + pub fn new(client: Client, opts: TursoSyncServerOpts) -> Result { + let auth_token_header = opts + .auth_token + .as_ref() + .map(|token| hyper::header::HeaderValue::from_str(&format!("Bearer {token}"))) + .transpose() + .map_err(|e| Error::Http(e.into()))?; + Ok(Self { + client, + opts, + auth_token_header, + }) + } + async fn send( + &self, + method: http::Method, + url: &str, + body: http_body_util::Full, + ) -> Result<(http::StatusCode, hyper::body::Incoming)> { + let url: hyper::Uri = url.parse().map_err(Error::Uri)?; + let mut request = 
request::Builder::new().uri(url).method(method); + if let Some(auth_token_header) = &self.auth_token_header { + request = request.header("Authorization", auth_token_header); + } + if let Some(encryption_key) = &self.opts.encryption_key { + request = request.header("x-turso-encryption-key", encryption_key); + } + let request = request.body(body).map_err(Error::Http)?; + let response = self.client.request(request).await; + let response = response.map_err(Error::HyperRequest)?; + let status = response.status(); + Ok((status, response.into_body())) + } +} + +impl SyncServer for TursoSyncServer { + type Stream = HyperStream; + async fn db_info(&self) -> Result { + tracing::debug!("db_info"); + let url = format!("{}/info", self.opts.sync_url); + let empty = http_body_util::Full::new(Bytes::new()); + let (status, body) = self.send(http::Method::GET, &url, empty).await?; + let body = aggregate_body(body).await?; + + if !status.is_success() { + return Err(sync_server_error(status, body)); + } + + let info = serde_json::from_reader(body.reader()).map_err(Error::JsonDecode)?; + tracing::debug!("db_info response: {:?}", info); + Ok(info) + } + + async fn wal_push( + &self, + baton: Option, + generation_id: usize, + start_frame: usize, + end_frame: usize, + frames: Vec, + ) -> Result { + tracing::debug!( + "wal_push: {}/{}/{} (baton: {:?})", + generation_id, + start_frame, + end_frame, + baton + ); + let url = if let Some(baton) = baton { + format!( + "{}/sync/{}/{}/{}/{}", + self.opts.sync_url, generation_id, start_frame, end_frame, baton + ) + } else { + format!( + "{}/sync/{}/{}/{}", + self.opts.sync_url, generation_id, start_frame, end_frame + ) + }; + let body = http_body_util::Full::new(Bytes::from(frames)); + let (status_code, body) = self.send(http::Method::POST, &url, body).await?; + let body = aggregate_body(body).await?; + + if !status_code.is_success() { + return Err(sync_server_error(status_code, body)); + } + + let status: DbSyncStatus = + 
serde_json::from_reader(body.reader()).map_err(Error::JsonDecode)?; + + match status.status.as_str() { + "ok" => Ok(status), + "conflict" => Err(Error::PushConflict), + "push_needed" => Err(Error::PushInconsistent(status)), + _ => Err(Error::SyncServerUnexpectedStatus(status)), + } + } + + async fn db_export(&self, generation_id: usize) -> Result { + tracing::debug!("db_export: {}", generation_id); + let url = format!("{}/export/{}", self.opts.sync_url, generation_id); + let empty = http_body_util::Full::new(Bytes::new()); + let (status, body) = self.send(http::Method::GET, &url, empty).await?; + if !status.is_success() { + let body = aggregate_body(body).await?; + return Err(sync_server_error(status, body)); + } + Ok(HyperStream { body }) + } + + async fn wal_pull(&self, generation_id: usize, start_frame: usize) -> Result { + let batch = self.opts.pull_batch_size.unwrap_or(DEFAULT_PULL_BATCH_SIZE); + let end_frame = start_frame + batch; + tracing::debug!("wall_pull: {}/{}/{}", generation_id, start_frame, end_frame); + let url = format!( + "{}/sync/{}/{}/{}", + self.opts.sync_url, generation_id, start_frame, end_frame + ); + let empty = http_body_util::Full::new(Bytes::new()); + let (status, body) = self.send(http::Method::GET, &url, empty).await?; + if status == http::StatusCode::BAD_REQUEST { + let body = aggregate_body(body).await?; + let status: DbSyncStatus = + serde_json::from_reader(body.reader()).map_err(Error::JsonDecode)?; + if status.status == "checkpoint_needed" { + return Err(Error::PullNeedCheckpoint(status)); + } else { + return Err(Error::SyncServerUnexpectedStatus(status)); + } + } + if !status.is_success() { + let body = aggregate_body(body).await?; + return Err(sync_server_error(status, body)); + } + Ok(HyperStream { body }) + } +} diff --git a/packages/turso-sync/src/test_context.rs b/packages/turso-sync/src/test_context.rs new file mode 100644 index 000000000..9f18294ea --- /dev/null +++ b/packages/turso-sync/src/test_context.rs @@ -0,0 +1,139 
@@ +use std::{ + collections::{HashMap, HashSet}, + future::Future, + pin::Pin, + sync::Arc, +}; + +use rand::SeedableRng; +use rand_chacha::ChaCha8Rng; +use tokio::sync::Mutex; + +use crate::{errors::Error, Result}; + +type PinnedFuture = Pin + Send>>; + +pub struct FaultInjectionPlan { + pub is_fault: Box PinnedFuture + Send + Sync>, +} + +pub enum FaultInjectionStrategy { + Disabled, + Record, + Enabled { plan: FaultInjectionPlan }, +} + +impl std::fmt::Debug for FaultInjectionStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Disabled => write!(f, "Disabled"), + Self::Record => write!(f, "Record"), + Self::Enabled { .. } => write!(f, "Enabled"), + } + } +} + +pub struct TestContext { + fault_injection: Mutex, + faulty_call: Mutex>, + rng: Mutex, +} + +pub struct FaultSession { + ctx: Arc, + recording: bool, + plans: Option>, +} + +impl FaultSession { + pub async fn next(&mut self) -> Option { + if !self.recording { + self.recording = true; + return Some(FaultInjectionStrategy::Record); + } + if self.plans.is_none() { + self.plans = Some(self.ctx.enumerate_simple_plans().await); + } + + let plans = self.plans.as_mut().unwrap(); + if plans.is_empty() { + return None; + } + + let plan = plans.pop().unwrap(); + Some(FaultInjectionStrategy::Enabled { plan }) + } +} + +impl TestContext { + pub fn new(seed: u64) -> Self { + Self { + rng: Mutex::new(ChaCha8Rng::seed_from_u64(seed)), + fault_injection: Mutex::new(FaultInjectionStrategy::Disabled), + faulty_call: Mutex::new(HashSet::new()), + } + } + pub async fn rng(&self) -> tokio::sync::MutexGuard { + self.rng.lock().await + } + pub fn fault_session(self: &Arc) -> FaultSession { + FaultSession { + ctx: self.clone(), + recording: false, + plans: None, + } + } + pub async fn switch_mode(&self, updated: FaultInjectionStrategy) { + let mut mode = self.fault_injection.lock().await; + tracing::info!("switch fault injection mode: {:?}", updated); + *mode = updated; + } + 
pub async fn enumerate_simple_plans(&self) -> Vec { + let mut plans = vec![]; + for call in self.faulty_call.lock().await.iter() { + let mut fault_counts = HashMap::new(); + fault_counts.insert(call.clone(), 1); + + let count = Arc::new(Mutex::new(1)); + let call = call.clone(); + plans.push(FaultInjectionPlan { + is_fault: Box::new(move |name, bt| { + let call = call.clone(); + let count = count.clone(); + Box::pin(async move { + if (name, bt) != call { + return false; + } + let mut count = count.lock().await; + *count -= 1; + *count >= 0 + }) + }), + }) + } + plans + } + pub async fn faulty_call(&self, name: &str) -> Result<()> { + tracing::trace!("faulty_call: {}", name); + tokio::task::yield_now().await; + if let FaultInjectionStrategy::Disabled = &*self.fault_injection.lock().await { + return Ok(()); + } + let bt = std::backtrace::Backtrace::force_capture().to_string(); + match &mut *self.fault_injection.lock().await { + FaultInjectionStrategy::Record => { + let mut call_sites = self.faulty_call.lock().await; + call_sites.insert((name.to_string(), bt)); + Ok(()) + } + FaultInjectionStrategy::Enabled { plan } => { + if plan.is_fault.as_ref()(name.to_string(), bt.clone()).await { + Err(Error::DatabaseSyncError("injected fault".to_string())) + } else { + Ok(()) + } + } + _ => unreachable!("Disabled case handled above"), + } + } +} diff --git a/packages/turso-sync/src/wal_session.rs b/packages/turso-sync/src/wal_session.rs new file mode 100644 index 000000000..88846de51 --- /dev/null +++ b/packages/turso-sync/src/wal_session.rs @@ -0,0 +1,40 @@ +use crate::Result; + +pub struct WalSession<'a> { + conn: &'a turso::Connection, + in_txn: bool, +} + +impl<'a> WalSession<'a> { + pub fn new(conn: &'a turso::Connection) -> Self { + Self { + conn, + in_txn: false, + } + } + pub fn begin(&mut self) -> Result<()> { + assert!(!self.in_txn); + self.conn.wal_insert_begin()?; + self.in_txn = true; + Ok(()) + } + pub fn end(&mut self) -> Result<()> { + assert!(self.in_txn); + 
self.conn.wal_insert_end()?; + self.in_txn = false; + Ok(()) + } + pub fn in_txn(&self) -> bool { + self.in_txn + } +} + +impl<'a> Drop for WalSession<'a> { + fn drop(&mut self) { + if self.in_txn { + let _ = self + .end() + .inspect_err(|e| tracing::error!("failed to close WAL session: {}", e)); + } + } +}