mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
fix: extension config write locks (#2283)
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -2100,6 +2100,16 @@ version = "2.0.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
|
checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fs2"
|
||||||
|
version = "0.4.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures"
|
name = "futures"
|
||||||
version = "0.3.31"
|
version = "0.3.31"
|
||||||
@@ -2393,6 +2403,7 @@ dependencies = [
|
|||||||
"ctor",
|
"ctor",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
"etcetera",
|
"etcetera",
|
||||||
|
"fs2",
|
||||||
"futures",
|
"futures",
|
||||||
"include_dir",
|
"include_dir",
|
||||||
"indoc",
|
"indoc",
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ jsonwebtoken = "9.3.1"
|
|||||||
|
|
||||||
# Added blake3 hashing library as a dependency
|
# Added blake3 hashing library as a dependency
|
||||||
blake3 = "1.5"
|
blake3 = "1.5"
|
||||||
|
fs2 = "0.4.3"
|
||||||
|
|
||||||
[target.'cfg(target_os = "windows")'.dependencies]
|
[target.'cfg(target_os = "windows")'.dependencies]
|
||||||
winapi = { version = "0.3", features = ["wincred"] }
|
winapi = { version = "0.3", features = ["wincred"] }
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs};
|
use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs};
|
||||||
|
use fs2::FileExt;
|
||||||
use keyring::Entry;
|
use keyring::Entry;
|
||||||
use once_cell::sync::{Lazy, OnceCell};
|
use once_cell::sync::{Lazy, OnceCell};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::env;
|
use std::env;
|
||||||
|
use std::fs::OpenOptions;
|
||||||
|
use std::io::Write;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
@@ -32,6 +35,8 @@ pub enum ConfigError {
|
|||||||
DirectoryError(String),
|
DirectoryError(String),
|
||||||
#[error("Failed to access keyring: {0}")]
|
#[error("Failed to access keyring: {0}")]
|
||||||
KeyringError(String),
|
KeyringError(String),
|
||||||
|
#[error("Failed to lock config file: {0}")]
|
||||||
|
LockError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<serde_json::Error> for ConfigError {
|
impl From<serde_json::Error> for ConfigError {
|
||||||
@@ -220,7 +225,22 @@ impl Config {
|
|||||||
.map_err(|e| ConfigError::DirectoryError(e.to_string()))?;
|
.map_err(|e| ConfigError::DirectoryError(e.to_string()))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::fs::write(&self.config_path, yaml_value)?;
|
// Open the file with write permissions, create if it doesn't exist
|
||||||
|
let mut file = OpenOptions::new()
|
||||||
|
.write(true)
|
||||||
|
.create(true)
|
||||||
|
.truncate(true)
|
||||||
|
.open(&self.config_path)?;
|
||||||
|
|
||||||
|
// Acquire an exclusive lock
|
||||||
|
file.lock_exclusive()
|
||||||
|
.map_err(|e| ConfigError::LockError(e.to_string()))?;
|
||||||
|
|
||||||
|
// Write the contents using the same file handle
|
||||||
|
file.write_all(yaml_value.as_bytes())?;
|
||||||
|
file.sync_all()?;
|
||||||
|
|
||||||
|
// Unlock is handled automatically when file is dropped
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -621,4 +641,171 @@ mod tests {
|
|||||||
cleanup_keyring()?;
|
cleanup_keyring()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_concurrent_writes() -> Result<(), ConfigError> {
|
||||||
|
use std::sync::{Arc, Barrier, Mutex};
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
let temp_file = NamedTempFile::new().unwrap();
|
||||||
|
let config = Arc::new(Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?);
|
||||||
|
let barrier = Arc::new(Barrier::new(3)); // For 3 concurrent threads
|
||||||
|
let values = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
let mut handles = vec![];
|
||||||
|
|
||||||
|
// Initialize with empty values
|
||||||
|
config.save_values(HashMap::new())?;
|
||||||
|
|
||||||
|
// Spawn 3 threads that will try to write simultaneously
|
||||||
|
for i in 0..3 {
|
||||||
|
let config = Arc::clone(&config);
|
||||||
|
let barrier = Arc::clone(&barrier);
|
||||||
|
let values = Arc::clone(&values);
|
||||||
|
let handle = thread::spawn(move || -> Result<(), ConfigError> {
|
||||||
|
// Wait for all threads to reach this point
|
||||||
|
barrier.wait();
|
||||||
|
|
||||||
|
// Get the lock and update values
|
||||||
|
let mut values = values.lock().unwrap();
|
||||||
|
values.insert(format!("key{}", i), Value::String(format!("value{}", i)));
|
||||||
|
|
||||||
|
// Write all values
|
||||||
|
config.save_values(values.clone())?;
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all threads to complete
|
||||||
|
for handle in handles {
|
||||||
|
handle.join().unwrap()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all values were written correctly
|
||||||
|
let final_values = config.load_values()?;
|
||||||
|
|
||||||
|
// Print the final values for debugging
|
||||||
|
println!("Final values: {:?}", final_values);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
final_values.len(),
|
||||||
|
3,
|
||||||
|
"Expected 3 values, got {}",
|
||||||
|
final_values.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
for i in 0..3 {
|
||||||
|
let key = format!("key{}", i);
|
||||||
|
let value = format!("value{}", i);
|
||||||
|
assert!(
|
||||||
|
final_values.get(&key).is_some(),
|
||||||
|
"Missing key {} in final values",
|
||||||
|
key
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
final_values.get(&key).unwrap(),
|
||||||
|
&Value::String(value),
|
||||||
|
"Incorrect value for key {}",
|
||||||
|
key
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_concurrent_extension_writes() -> Result<(), ConfigError> {
|
||||||
|
use std::sync::{Arc, Barrier};
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
let temp_file = NamedTempFile::new().unwrap();
|
||||||
|
let config = Arc::new(Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?);
|
||||||
|
let barrier = Arc::new(Barrier::new(3)); // For 3 concurrent threads
|
||||||
|
let mut handles = vec![];
|
||||||
|
|
||||||
|
// Initialize with empty values
|
||||||
|
config.save_values(HashMap::new())?;
|
||||||
|
|
||||||
|
// Spawn 3 threads that will try to write extension configs simultaneously
|
||||||
|
for i in 0..3 {
|
||||||
|
let config = Arc::clone(&config);
|
||||||
|
let barrier = Arc::clone(&barrier);
|
||||||
|
let handle = thread::spawn(move || -> Result<(), ConfigError> {
|
||||||
|
// Wait for all threads to reach this point
|
||||||
|
barrier.wait();
|
||||||
|
|
||||||
|
// Add a small random delay to increase chance of concurrent access
|
||||||
|
thread::sleep(Duration::from_millis(i * 10));
|
||||||
|
|
||||||
|
let extension_key = format!("extension_{}", i);
|
||||||
|
let mut values = config.load_values()?;
|
||||||
|
|
||||||
|
values.insert(
|
||||||
|
extension_key.clone(),
|
||||||
|
serde_json::json!({
|
||||||
|
"name": format!("test_extension_{}", i),
|
||||||
|
"version": format!("1.0.{}", i),
|
||||||
|
"enabled": true,
|
||||||
|
"settings": {
|
||||||
|
"option1": format!("value{}", i),
|
||||||
|
"option2": i
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Write all values atomically
|
||||||
|
config.save_values(values)?;
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all threads to complete
|
||||||
|
for handle in handles {
|
||||||
|
handle.join().unwrap()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all extension configs were written correctly
|
||||||
|
let final_values = config.load_values()?;
|
||||||
|
|
||||||
|
// Print the final values for debugging
|
||||||
|
println!("Final values: {:?}", final_values);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
final_values.len(),
|
||||||
|
3,
|
||||||
|
"Expected 3 extension configs, got {}",
|
||||||
|
final_values.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
for i in 0..3 {
|
||||||
|
let extension_key = format!("extension_{}", i);
|
||||||
|
|
||||||
|
let config = final_values
|
||||||
|
.get(&extension_key)
|
||||||
|
.expect(&format!("Missing extension config for {}", extension_key));
|
||||||
|
|
||||||
|
// Verify the structure matches what we wrote
|
||||||
|
let config_obj = config.as_object().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
config_obj.get("name").unwrap().as_str().unwrap(),
|
||||||
|
format!("test_extension_{}", i)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config_obj.get("version").unwrap().as_str().unwrap(),
|
||||||
|
format!("1.0.{}", i)
|
||||||
|
);
|
||||||
|
assert!(config_obj.get("enabled").unwrap().as_bool().unwrap());
|
||||||
|
|
||||||
|
let settings = config_obj.get("settings").unwrap().as_object().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
settings.get("option1").unwrap().as_str().unwrap(),
|
||||||
|
format!("value{}", i)
|
||||||
|
);
|
||||||
|
assert_eq!(settings.get("option2").unwrap().as_i64().unwrap() as i32, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
"license": {
|
"license": {
|
||||||
"name": "Apache-2.0"
|
"name": "Apache-2.0"
|
||||||
},
|
},
|
||||||
"version": "1.0.18"
|
"version": "1.0.19"
|
||||||
},
|
},
|
||||||
"paths": {
|
"paths": {
|
||||||
"/agent/tools": {
|
"/agent/tools": {
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ export default function ExtensionItem({
|
|||||||
<Switch
|
<Switch
|
||||||
checked={(isToggling && visuallyEnabled) || extension.enabled}
|
checked={(isToggling && visuallyEnabled) || extension.enabled}
|
||||||
onCheckedChange={() => handleToggle(extension)}
|
onCheckedChange={() => handleToggle(extension)}
|
||||||
|
disabled={isToggling}
|
||||||
variant="mono"
|
variant="mono"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Reference in New Issue
Block a user