From a1b46697f3c48a24c46e4577c966129bcf8ea9a1 Mon Sep 17 00:00:00 2001 From: Alex Hancock Date: Mon, 21 Apr 2025 14:36:12 -0400 Subject: [PATCH] fix: extension config write locks (#2283) --- Cargo.lock | 11 + crates/goose/Cargo.toml | 1 + crates/goose/src/config/base.rs | 189 +++++++++++++++++- ui/desktop/openapi.json | 2 +- .../subcomponents/ExtensionItem.tsx | 1 + 5 files changed, 202 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bd3470ff..22f21f8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2100,6 +2100,16 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "futures" version = "0.3.31" @@ -2393,6 +2403,7 @@ dependencies = [ "ctor", "dotenv", "etcetera", + "fs2", "futures", "include_dir", "indoc", diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 73fa74b5..70388a8d 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -72,6 +72,7 @@ jsonwebtoken = "9.3.1" # Added blake3 hashing library as a dependency blake3 = "1.5" +fs2 = "0.4.3" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose/src/config/base.rs b/crates/goose/src/config/base.rs index 4f11f3f5..cdaf3c65 100644 --- a/crates/goose/src/config/base.rs +++ b/crates/goose/src/config/base.rs @@ -1,10 +1,13 @@ use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs}; +use fs2::FileExt; use keyring::Entry; use once_cell::sync::{Lazy, OnceCell}; use serde::Deserialize; use serde_json::Value; use std::collections::HashMap; use std::env; +use std::fs::OpenOptions; +use std::io::Write; use std::path::{Path, PathBuf}; use thiserror::Error; @@ -32,6 +35,8 @@ pub enum ConfigError { DirectoryError(String), #[error("Failed to access keyring: {0}")] KeyringError(String), + #[error("Failed to lock config file: {0}")] + LockError(String), } impl From for ConfigError { @@ -220,7 +225,22 @@ impl Config { .map_err(|e| ConfigError::DirectoryError(e.to_string()))?; } - std::fs::write(&self.config_path, yaml_value)?; + // Open the file with write permissions, create if it doesn't exist + let mut file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&self.config_path)?; + + // Acquire an exclusive lock + file.lock_exclusive() + .map_err(|e| ConfigError::LockError(e.to_string()))?; + + // Write the contents using the same file handle + file.write_all(yaml_value.as_bytes())?; + file.sync_all()?; + + // Unlock is handled automatically when file is dropped Ok(()) } @@ -621,4 +641,171 @@ mod tests { cleanup_keyring()?; Ok(()) } + + #[test] + fn test_concurrent_writes() -> Result<(), ConfigError> { + use std::sync::{Arc, Barrier, Mutex}; + use std::thread; + + let temp_file = NamedTempFile::new().unwrap(); + let config = Arc::new(Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?); + let barrier = Arc::new(Barrier::new(3)); // For 3 concurrent threads + let values = Arc::new(Mutex::new(HashMap::new())); + let mut handles = vec![]; + + // Initialize with empty values + config.save_values(HashMap::new())?; + + // Spawn 3 threads that will try to write simultaneously + for i in 0..3 { + let config = Arc::clone(&config); + let barrier = Arc::clone(&barrier); + let values = Arc::clone(&values); + let handle = thread::spawn(move || -> Result<(), ConfigError> { + // Wait for all threads to reach this point + barrier.wait(); + + // Get the lock and update values + let mut values = values.lock().unwrap(); + values.insert(format!("key{}", i), Value::String(format!("value{}", i))); + + // Write all values + config.save_values(values.clone())?; + Ok(()) + }); + handles.push(handle); + } + + // Wait for all threads to complete + for handle in handles { + handle.join().unwrap()?; + } + + // Verify all values were written correctly + let final_values = config.load_values()?; + + // Print the final values for debugging + println!("Final values: {:?}", final_values); + + assert_eq!( + final_values.len(), + 3, + "Expected 3 values, got {}", + final_values.len() + ); + + for i in 0..3 { + let key = format!("key{}", i); + let value = format!("value{}", i); + assert!( + final_values.get(&key).is_some(), + "Missing key {} in final values", + key + ); + assert_eq!( + final_values.get(&key).unwrap(), + &Value::String(value), + "Incorrect value for key {}", + key + ); + } + + Ok(()) + } + + #[test] + fn test_concurrent_extension_writes() -> Result<(), ConfigError> { + use std::sync::{Arc, Barrier}; + use std::thread; + use std::time::Duration; + + let temp_file = NamedTempFile::new().unwrap(); + let config = Arc::new(Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?); + let barrier = Arc::new(Barrier::new(3)); // For 3 concurrent threads + let mut handles = vec![]; + + // Initialize with empty values + config.save_values(HashMap::new())?; + + // Spawn 3 threads that will try to write extension configs simultaneously + for i in 0..3 { + let config = Arc::clone(&config); + let barrier = Arc::clone(&barrier); + let handle = thread::spawn(move || -> Result<(), ConfigError> { + // Wait for all threads to reach this point + barrier.wait(); + + // Add a small random delay to increase chance of concurrent access + thread::sleep(Duration::from_millis(i * 10)); + + let extension_key = format!("extension_{}", i); + let mut values = config.load_values()?; + + values.insert( + extension_key.clone(), + serde_json::json!({ + "name": format!("test_extension_{}", i), + "version": format!("1.0.{}", i), + "enabled": true, + "settings": { + "option1": format!("value{}", i), + "option2": i + } + }), + ); + + // Write all values atomically + config.save_values(values)?; + Ok(()) + }); + handles.push(handle); + } + + // Wait for all threads to complete + for handle in handles { + handle.join().unwrap()?; + } + + // Verify all extension configs were written correctly + let final_values = config.load_values()?; + + // Print the final values for debugging + println!("Final values: {:?}", final_values); + + assert_eq!( + final_values.len(), + 3, + "Expected 3 extension configs, got {}", + final_values.len() + ); + + for i in 0..3 { + let extension_key = format!("extension_{}", i); + + let config = final_values + .get(&extension_key) + .expect(&format!("Missing extension config for {}", extension_key)); + + // Verify the structure matches what we wrote + let config_obj = config.as_object().unwrap(); + assert_eq!( + config_obj.get("name").unwrap().as_str().unwrap(), + format!("test_extension_{}", i) + ); + assert_eq!( + config_obj.get("version").unwrap().as_str().unwrap(), + format!("1.0.{}", i) + ); + assert!(config_obj.get("enabled").unwrap().as_bool().unwrap()); + + let settings = config_obj.get("settings").unwrap().as_object().unwrap(); + assert_eq!( + settings.get("option1").unwrap().as_str().unwrap(), + format!("value{}", i) + ); + assert_eq!(settings.get("option2").unwrap().as_i64().unwrap() as i32, i); + } + + Ok(()) + } } diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 050e40fd..4e20f6cc 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -10,7 +10,7 @@ "license": { "name": "Apache-2.0" }, - "version": "1.0.18" + "version": "1.0.19" }, "paths": { "/agent/tools": { diff --git a/ui/desktop/src/components/settings_v2/extensions/subcomponents/ExtensionItem.tsx b/ui/desktop/src/components/settings_v2/extensions/subcomponents/ExtensionItem.tsx index 214a503c..ccc379b3 100644 --- a/ui/desktop/src/components/settings_v2/extensions/subcomponents/ExtensionItem.tsx +++ b/ui/desktop/src/components/settings_v2/extensions/subcomponents/ExtensionItem.tsx @@ -94,6 +94,7 @@ export default function ExtensionItem({ handleToggle(extension)} + disabled={isToggling} variant="mono" />