use clap::Parser;
use rand::rngs::SmallRng;
use rand::{Rng, RngCore, SeedableRng};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Barrier};
use std::time::{Duration, Instant};
use turso::{Builder, Database, EncryptionOpts, Result};

#[derive(Parser)]
#[command(name = "encryption-throughput")]
#[command(about = "Encryption throughput benchmark on Turso DB")]
struct Args {
    /// More than one thread does not work yet
    #[arg(short = 't', long = "threads", default_value = "1")]
    threads: usize,

    /// Number of operations per transaction
    #[arg(short = 'b', long = "batch-size", default_value = "100")]
    batch_size: usize,

    /// Number of transactions per thread
    #[arg(short = 'i', long = "iterations", default_value = "10")]
    iterations: usize,

    #[arg(
        short = 'r',
        long = "read-ratio",
        help = "Percentage of operations that should be reads (0-100)"
    )]
    read_ratio: Option<u8>,

    #[arg(
        short = 'w',
        long = "write-ratio",
        help = "Percentage of operations that should be writes (0-100)"
    )]
    write_ratio: Option<u8>,

    #[arg(
        long = "encryption",
        action = clap::ArgAction::SetTrue,
        help = "Enable database encryption"
    )]
    encryption: bool,

    #[arg(
        long = "cipher",
        default_value = "aegis-256",
        help = "Encryption cipher to use (only relevant if --encryption is set)"
    )]
    cipher: String,

    #[arg(
        long = "think",
        default_value = "0",
        help = "Per-transaction think time (ms)"
    )]
    think: u64,

    #[arg(
        long = "timeout",
        default_value = "30000",
        help = "Busy timeout in milliseconds"
    )]
    timeout: u64,

    #[arg(
        long = "seed",
        default_value = "2167532792061351037",
        help = "Random seed for reproducible workloads"
    )]
    seed: u64,
}

#[derive(Debug)]
struct WorkerStats {
    transactions_completed: u64,
    reads_completed: u64,
    writes_completed: u64,
    reads_found: u64,
    reads_not_found: u64,
    total_transaction_time: Duration,
    total_read_time: Duration,
    total_write_time: Duration,
}

#[derive(Debug, Clone)]
struct SharedState {
    max_inserted_id: Arc<AtomicU64>,
}
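// Example invocations (a sketch; the `--bin` name is an assumption and depends
// on how this benchmark is registered in Cargo):
//
//   cargo run --release --bin encryption-throughput -- --batch-size 1000 --iterations 100
//   cargo run --release --bin encryption-throughput -- --encryption --cipher aegis-256 --read-ratio 50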
#[tokio::main]
async fn main() -> Result<()> {
    let _ = tracing_subscriber::fmt::try_init();
    let args = Args::parse();

    let read_ratio = match (args.read_ratio, args.write_ratio) {
        (Some(_), Some(_)) => {
            eprintln!("Error: Cannot specify both --read-ratio and --write-ratio");
            std::process::exit(1);
        }
        (Some(r), None) => {
            if r > 100 {
                eprintln!("Error: read-ratio must be between 0 and 100");
                std::process::exit(1);
            }
            r
        }
        (None, Some(w)) => {
            if w > 100 {
                eprintln!("Error: write-ratio must be between 0 and 100");
                std::process::exit(1);
            }
            100 - w
        }
        // let's default to 0% reads (100% writes)
        (None, None) => 0,
    };

    println!(
        "Running encryption throughput benchmark with {} threads, {} batch size, {} iterations",
        args.threads, args.batch_size, args.iterations
    );
    println!(
        "Read/Write ratio: {}% reads, {}% writes",
        read_ratio,
        100 - read_ratio
    );
    println!("Encryption enabled: {}", args.encryption);
    println!("Random seed: {}", args.seed);

    let encryption_opts = if args.encryption {
        // The key is derived deterministically from --seed, so reruns with the
        // same seed produce the same hexkey.
        let mut key_rng = SmallRng::seed_from_u64(args.seed);
        let key_size = get_key_size_for_cipher(&args.cipher);
        let mut key = vec![0u8; key_size];
        key_rng.fill_bytes(&mut key);
        let config = EncryptionOpts {
            cipher: args.cipher.clone(),
            hexkey: hex::encode(&key),
        };
        println!("Cipher: {}", config.cipher);
        println!("Hexkey: {}", config.hexkey);
        Some(config)
    } else {
        None
    };

    let db_path = "encryption_throughput_test.db";
    if std::path::Path::new(db_path).exists() {
        std::fs::remove_file(db_path).expect("Failed to remove existing database");
    }
    let wal_path = "encryption_throughput_test.db-wal";
    if std::path::Path::new(wal_path).exists() {
        std::fs::remove_file(wal_path).expect("Failed to remove existing WAL file");
    }

    let db = setup_database(db_path, &encryption_opts).await?;

    // Create a value shared between all the threads; we use it to track the
    // max inserted id so that reads only target rows that exist.
    let shared_state = SharedState {
        max_inserted_id: Arc::new(AtomicU64::new(0)),
    };

    let start_barrier = Arc::new(Barrier::new(args.threads));
    let mut handles = Vec::new();
    let timeout = Duration::from_millis(args.timeout);
    let overall_start = Instant::now();

    for thread_id in 0..args.threads {
        let db_clone = db.clone();
        let barrier = Arc::clone(&start_barrier);
        let encryption_opts_clone = encryption_opts.clone();
        let shared_state_clone = shared_state.clone();
        let handle = tokio::task::spawn(worker_thread(
            thread_id,
            db_clone,
            args.batch_size,
            args.iterations,
            barrier,
            read_ratio,
            encryption_opts_clone,
            args.think,
            timeout,
            shared_state_clone,
            args.seed,
        ));
        handles.push(handle);
    }

    let mut total_transactions = 0;
    let mut total_reads = 0;
    let mut total_writes = 0;
    let mut total_reads_found = 0;
    let mut total_reads_not_found = 0;
    let mut total_read_time = Duration::ZERO;
    let mut total_write_time = Duration::ZERO;

    for (idx, handle) in handles.into_iter().enumerate() {
        match handle.await {
            Ok(Ok(stats)) => {
                total_transactions += stats.transactions_completed;
                total_reads += stats.reads_completed;
                total_writes += stats.writes_completed;
                total_reads_found += stats.reads_found;
                total_reads_not_found += stats.reads_not_found;
                total_read_time += stats.total_read_time;
                total_write_time += stats.total_write_time;
            }
            Ok(Err(e)) => {
                eprintln!("Thread error {idx}: {e}");
                return Err(e);
            }
            Err(_) => {
                eprintln!("Thread panicked");
                std::process::exit(1);
            }
        }
    }

    let overall_elapsed = overall_start.elapsed();
    let total_operations = total_reads + total_writes;
    let transaction_throughput = (total_transactions as f64) / overall_elapsed.as_secs_f64();
    let operation_throughput = (total_operations as f64) / overall_elapsed.as_secs_f64();
    let read_throughput = if total_reads > 0 {
        (total_reads as f64) / overall_elapsed.as_secs_f64()
    } else {
        0.0
    };
    let write_throughput = if total_writes > 0 {
        (total_writes as f64) / overall_elapsed.as_secs_f64()
    } else {
        0.0
    };
    let avg_ops_per_txn = (total_operations as f64) / (total_transactions as f64);

    // Turso DB is too damn fast, so let's measure latencies in microseconds.
    let avg_read_latency_us = if total_reads > 0 {
        total_read_time.as_micros() as f64 / total_reads as f64
    } else {
        0.0
    };
    let avg_write_latency_us = if total_writes > 0 {
        total_write_time.as_micros() as f64 / total_writes as f64
    } else {
        0.0
    };

    println!("\n=== BENCHMARK RESULTS ===");
    println!("Total transactions: {total_transactions}");
    println!("Total operations: {total_operations}");
    println!("Operations per transaction: {avg_ops_per_txn:.1}");
    println!("Total time: {:.2}s", overall_elapsed.as_secs_f64());
    println!();
    println!("Transaction throughput: {transaction_throughput:.2} txns/sec");
    println!("Operation throughput: {operation_throughput:.2} ops/sec");

    // "not found" should be zero since we track the max inserted id
    // todo(v): probably handle the not-found error and remove the max id
    if total_reads > 0 {
        println!(
            " - Read operations: {total_reads} ({total_reads_found} found, {total_reads_not_found} not found)"
        );
        println!(" - Read throughput: {read_throughput:.2} reads/sec");
        println!(" - Read total time: {:.2}s", total_read_time.as_secs_f64());
        println!(" - Read avg latency: {avg_read_latency_us:.2}μs");
    }
    if total_writes > 0 {
        println!(" - Write operations: {total_writes}");
        println!(" - Write throughput: {write_throughput:.2} writes/sec");
        println!(
            " - Write total time: {:.2}s",
            total_write_time.as_secs_f64()
        );
        println!(" - Write avg latency: {avg_write_latency_us:.2}μs");
    }

    println!("\nConfiguration:");
    println!("Threads: {}", args.threads);
    println!("Batch size: {}", args.batch_size);
    println!("Iterations per thread: {}", args.iterations);
    println!("Encryption: {}", args.encryption);
    println!("Seed: {}", args.seed);

    if let Ok(metadata) = std::fs::metadata(db_path) {
        println!("Database file size: {} bytes", metadata.len());
    }

    Ok(())
}
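/// Returns the key size in bytes for the given cipher name: the AES-128-GCM /
/// AEGIS-128 family takes 128-bit (16-byte) keys, the AES-256-GCM / AEGIS-256
/// family 256-bit (32-byte) keys; unrecognized names fall back to 32 bytes.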
fn get_key_size_for_cipher(cipher: &str) -> usize {
    match cipher.to_lowercase().as_str() {
        "aes-128-gcm" | "aegis-128l" | "aegis-128x2" | "aegis-128x4" => 16,
        "aes-256-gcm" | "aegis-256" | "aegis-256x2" | "aegis-256x4" => 32,
        _ => 32, // default to 256-bit key
    }
}

async fn setup_database(
    db_path: &str,
    encryption_opts: &Option<EncryptionOpts>,
) -> Result<Database> {
    let builder = Builder::new_local(db_path);
    let builder = if let Some(opts) = encryption_opts {
        builder
            .experimental_encryption(true)
            .with_encryption(opts.clone())
    } else {
        builder
    };
    let db = builder.build().await?;

    let conn = db.connect()?;
    if let Some(config) = encryption_opts {
        conn.execute(&format!("PRAGMA cipher='{}'", config.cipher), ())
            .await?;
        conn.execute(&format!("PRAGMA hexkey='{}'", config.hexkey), ())
            .await?;
    }

    // todo(v): probably store blobs and then have an option for the randomblob size
    conn.execute(
        "CREATE TABLE IF NOT EXISTS test_table (
            id INTEGER PRIMARY KEY,
            data TEXT NOT NULL
        )",
        (),
    )
    .await?;

    println!("Database created at: {db_path}");
    Ok(db)
}
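// Note: each iteration opens a fresh connection and re-applies the encryption
// PRAGMAs before the transaction timer starts, so connection setup is excluded
// from the measured latencies. Also, `std::sync::Barrier::wait` blocks the
// executing tokio worker thread; that is acceptable while only a single
// benchmark thread is supported (see the --threads help text), but a
// multi-threaded variant would likely want `tokio::sync::Barrier` instead.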
#[allow(clippy::too_many_arguments)]
async fn worker_thread(
    thread_id: usize,
    db: Database,
    batch_size: usize,
    iterations: usize,
    start_barrier: Arc<Barrier>,
    read_ratio: u8,
    encryption_opts: Option<EncryptionOpts>,
    think_ms: u64,
    timeout: Duration,
    shared_state: SharedState,
    base_seed: u64,
) -> Result<WorkerStats> {
    start_barrier.wait();
    let start_time = Instant::now();

    let mut stats = WorkerStats {
        transactions_completed: 0,
        reads_completed: 0,
        writes_completed: 0,
        reads_found: 0,
        reads_not_found: 0,
        total_transaction_time: Duration::ZERO,
        total_read_time: Duration::ZERO,
        total_write_time: Duration::ZERO,
    };

    // Per-thread seed so each worker gets a reproducible but distinct stream.
    let thread_seed = base_seed.wrapping_add(thread_id as u64);
    let mut rng = SmallRng::seed_from_u64(thread_seed);

    for iteration in 0..iterations {
        let conn = db.connect()?;
        if let Some(config) = &encryption_opts {
            conn.execute(&format!("PRAGMA cipher='{}'", config.cipher), ())
                .await?;
            conn.execute(&format!("PRAGMA hexkey='{}'", config.hexkey), ())
                .await?;
        }
        conn.busy_timeout(timeout)?;

        let mut insert_stmt = conn
            .prepare("INSERT INTO test_table (id, data) VALUES (?, ?)")
            .await?;

        let transaction_start = Instant::now();
        conn.execute("BEGIN", ()).await?;

        for i in 0..batch_size {
            let should_read = rng.random_range(0..100) < read_ratio;
            if should_read {
                // only attempt reads if we have inserted some data
                let max_id = shared_state.max_inserted_id.load(Ordering::Relaxed);
                if max_id > 0 {
                    let read_id = rng.random_range(1..=max_id);
                    let read_start = Instant::now();
                    let row = conn
                        .query(
                            "SELECT data FROM test_table WHERE id = ?",
                            turso::params::Params::Positional(vec![turso::Value::Integer(
                                read_id as i64,
                            )]),
                        )
                        .await;
                    let read_elapsed = read_start.elapsed();
                    stats.total_read_time += read_elapsed;
                    match row {
                        Ok(_) => stats.reads_found += 1,
                        Err(turso::Error::QueryReturnedNoRows) => stats.reads_not_found += 1,
                        Err(e) => return Err(e),
                    };
                    stats.reads_completed += 1;
                } else {
                    // if no data has been inserted yet, convert the read to a write.
                    // Ids are disjoint per thread: e.g. with iterations=10 and
                    // batch_size=100, thread 0 writes ids 1..=1000 and thread 1
                    // writes ids 1001..=2000, so inserts never collide.
                    let id = thread_id * iterations * batch_size + iteration * batch_size + i + 1;
                    let write_start = Instant::now();
                    insert_stmt
                        .execute(turso::params::Params::Positional(vec![
                            turso::Value::Integer(id as i64),
                            turso::Value::Text(format!("data_{id}")),
                        ]))
                        .await?;
                    let write_elapsed = write_start.elapsed();
                    stats.total_write_time += write_elapsed;
                    shared_state
                        .max_inserted_id
                        .fetch_max(id as u64, Ordering::Relaxed);
                    stats.writes_completed += 1;
                }
            } else {
                let id = thread_id * iterations * batch_size + iteration * batch_size + i + 1;
                // Let's measure the "write" time. Note that this write might not hit
                // disk and do i/o, since fsync is called only on commit at the end.
                // That's okay, since we want to compare the overhead of encryption.
                let write_start = Instant::now();
                insert_stmt
                    .execute(turso::params::Params::Positional(vec![
                        turso::Value::Integer(id as i64),
                        turso::Value::Text(format!("data_{id}")),
                    ]))
                    .await?;
                let write_elapsed = write_start.elapsed();
                stats.total_write_time += write_elapsed;
                shared_state
                    .max_inserted_id
                    .fetch_max(id as u64, Ordering::Relaxed);
                stats.writes_completed += 1;
            }
        }

        if think_ms > 0 {
            tokio::time::sleep(Duration::from_millis(think_ms)).await;
        }

        conn.execute("COMMIT", ()).await?;
        let transaction_elapsed = transaction_start.elapsed();
        stats.transactions_completed += 1;
        stats.total_transaction_time += transaction_elapsed;
    }

    let elapsed = start_time.elapsed();
    let total_ops = stats.reads_completed + stats.writes_completed;
    let transaction_throughput = (stats.transactions_completed as f64) / elapsed.as_secs_f64();
    let operation_throughput = (total_ops as f64) / elapsed.as_secs_f64();
    let avg_txn_latency =
        stats.total_transaction_time.as_secs_f64() * 1000.0 / stats.transactions_completed as f64;
    let avg_read_latency = if stats.reads_completed > 0 {
        (stats.total_read_time.as_secs_f64() * 1_000_000.0) / (stats.reads_completed as f64)
    } else {
        0.0
    };
    let avg_write_latency = if stats.writes_completed > 0 {
        (stats.total_write_time.as_secs_f64() * 1_000_000.0) / (stats.writes_completed as f64)
    } else {
        0.0
    };

    println!(
        "Thread {}: {} txns ({} ops: {} reads, {} writes) in {:.2}s ({:.2} txns/sec, {:.2} ops/sec, {:.2}ms avg txn latency)",
        thread_id,
        stats.transactions_completed,
        total_ops,
        stats.reads_completed,
        stats.writes_completed,
        elapsed.as_secs_f64(),
        transaction_throughput,
        operation_throughput,
        avg_txn_latency
    );
    if stats.reads_completed > 0 {
        println!(
            " Thread {} reads: {} found, {} not found (avg latency: {:.2}μs)",
            thread_id, stats.reads_found, stats.reads_not_found, avg_read_latency
        );
    }
    if stats.writes_completed > 0 {
        println!(
            " Thread {} writes: {} (avg latency: {:.2}μs)",
            thread_id, stats.writes_completed, avg_write_latency
        );
    }

    Ok(stats)
}
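// A minimal sanity check for the key-size lookup; the cipher names and the
// 32-byte fallback are taken from `get_key_size_for_cipher` above.
#[cfg(test)]
mod tests {
    use super::get_key_size_for_cipher;

    #[test]
    fn key_sizes_match_cipher_families() {
        assert_eq!(get_key_size_for_cipher("aes-128-gcm"), 16);
        assert_eq!(get_key_size_for_cipher("AEGIS-256"), 32); // lookup is case-insensitive
        assert_eq!(get_key_size_for_cipher("unknown-cipher"), 32); // defaults to 256-bit
    }
}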