mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
fix: extension config write locks (#2283)
This commit is contained in:
@@ -72,6 +72,7 @@ jsonwebtoken = "9.3.1"
|
||||
|
||||
# Added blake3 hashing library as a dependency
|
||||
blake3 = "1.5"
|
||||
fs2 = "0.4.3"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
winapi = { version = "0.3", features = ["wincred"] }
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs};
|
||||
use fs2::FileExt;
|
||||
use keyring::Entry;
|
||||
use once_cell::sync::{Lazy, OnceCell};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -32,6 +35,8 @@ pub enum ConfigError {
|
||||
DirectoryError(String),
|
||||
#[error("Failed to access keyring: {0}")]
|
||||
KeyringError(String),
|
||||
#[error("Failed to lock config file: {0}")]
|
||||
LockError(String),
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for ConfigError {
|
||||
@@ -220,7 +225,22 @@ impl Config {
|
||||
.map_err(|e| ConfigError::DirectoryError(e.to_string()))?;
|
||||
}
|
||||
|
||||
std::fs::write(&self.config_path, yaml_value)?;
|
||||
// Open the file with write permissions, create if it doesn't exist
|
||||
let mut file = OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(&self.config_path)?;
|
||||
|
||||
// Acquire an exclusive lock
|
||||
file.lock_exclusive()
|
||||
.map_err(|e| ConfigError::LockError(e.to_string()))?;
|
||||
|
||||
// Write the contents using the same file handle
|
||||
file.write_all(yaml_value.as_bytes())?;
|
||||
file.sync_all()?;
|
||||
|
||||
// Unlock is handled automatically when file is dropped
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -621,4 +641,171 @@ mod tests {
|
||||
cleanup_keyring()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_writes() -> Result<(), ConfigError> {
|
||||
use std::sync::{Arc, Barrier, Mutex};
|
||||
use std::thread;
|
||||
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let config = Arc::new(Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?);
|
||||
let barrier = Arc::new(Barrier::new(3)); // For 3 concurrent threads
|
||||
let values = Arc::new(Mutex::new(HashMap::new()));
|
||||
let mut handles = vec![];
|
||||
|
||||
// Initialize with empty values
|
||||
config.save_values(HashMap::new())?;
|
||||
|
||||
// Spawn 3 threads that will try to write simultaneously
|
||||
for i in 0..3 {
|
||||
let config = Arc::clone(&config);
|
||||
let barrier = Arc::clone(&barrier);
|
||||
let values = Arc::clone(&values);
|
||||
let handle = thread::spawn(move || -> Result<(), ConfigError> {
|
||||
// Wait for all threads to reach this point
|
||||
barrier.wait();
|
||||
|
||||
// Get the lock and update values
|
||||
let mut values = values.lock().unwrap();
|
||||
values.insert(format!("key{}", i), Value::String(format!("value{}", i)));
|
||||
|
||||
// Write all values
|
||||
config.save_values(values.clone())?;
|
||||
Ok(())
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads to complete
|
||||
for handle in handles {
|
||||
handle.join().unwrap()?;
|
||||
}
|
||||
|
||||
// Verify all values were written correctly
|
||||
let final_values = config.load_values()?;
|
||||
|
||||
// Print the final values for debugging
|
||||
println!("Final values: {:?}", final_values);
|
||||
|
||||
assert_eq!(
|
||||
final_values.len(),
|
||||
3,
|
||||
"Expected 3 values, got {}",
|
||||
final_values.len()
|
||||
);
|
||||
|
||||
for i in 0..3 {
|
||||
let key = format!("key{}", i);
|
||||
let value = format!("value{}", i);
|
||||
assert!(
|
||||
final_values.get(&key).is_some(),
|
||||
"Missing key {} in final values",
|
||||
key
|
||||
);
|
||||
assert_eq!(
|
||||
final_values.get(&key).unwrap(),
|
||||
&Value::String(value),
|
||||
"Incorrect value for key {}",
|
||||
key
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_extension_writes() -> Result<(), ConfigError> {
|
||||
use std::sync::{Arc, Barrier};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let config = Arc::new(Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?);
|
||||
let barrier = Arc::new(Barrier::new(3)); // For 3 concurrent threads
|
||||
let mut handles = vec![];
|
||||
|
||||
// Initialize with empty values
|
||||
config.save_values(HashMap::new())?;
|
||||
|
||||
// Spawn 3 threads that will try to write extension configs simultaneously
|
||||
for i in 0..3 {
|
||||
let config = Arc::clone(&config);
|
||||
let barrier = Arc::clone(&barrier);
|
||||
let handle = thread::spawn(move || -> Result<(), ConfigError> {
|
||||
// Wait for all threads to reach this point
|
||||
barrier.wait();
|
||||
|
||||
// Add a small random delay to increase chance of concurrent access
|
||||
thread::sleep(Duration::from_millis(i * 10));
|
||||
|
||||
let extension_key = format!("extension_{}", i);
|
||||
let mut values = config.load_values()?;
|
||||
|
||||
values.insert(
|
||||
extension_key.clone(),
|
||||
serde_json::json!({
|
||||
"name": format!("test_extension_{}", i),
|
||||
"version": format!("1.0.{}", i),
|
||||
"enabled": true,
|
||||
"settings": {
|
||||
"option1": format!("value{}", i),
|
||||
"option2": i
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Write all values atomically
|
||||
config.save_values(values)?;
|
||||
Ok(())
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads to complete
|
||||
for handle in handles {
|
||||
handle.join().unwrap()?;
|
||||
}
|
||||
|
||||
// Verify all extension configs were written correctly
|
||||
let final_values = config.load_values()?;
|
||||
|
||||
// Print the final values for debugging
|
||||
println!("Final values: {:?}", final_values);
|
||||
|
||||
assert_eq!(
|
||||
final_values.len(),
|
||||
3,
|
||||
"Expected 3 extension configs, got {}",
|
||||
final_values.len()
|
||||
);
|
||||
|
||||
for i in 0..3 {
|
||||
let extension_key = format!("extension_{}", i);
|
||||
|
||||
let config = final_values
|
||||
.get(&extension_key)
|
||||
.expect(&format!("Missing extension config for {}", extension_key));
|
||||
|
||||
// Verify the structure matches what we wrote
|
||||
let config_obj = config.as_object().unwrap();
|
||||
assert_eq!(
|
||||
config_obj.get("name").unwrap().as_str().unwrap(),
|
||||
format!("test_extension_{}", i)
|
||||
);
|
||||
assert_eq!(
|
||||
config_obj.get("version").unwrap().as_str().unwrap(),
|
||||
format!("1.0.{}", i)
|
||||
);
|
||||
assert!(config_obj.get("enabled").unwrap().as_bool().unwrap());
|
||||
|
||||
let settings = config_obj.get("settings").unwrap().as_object().unwrap();
|
||||
assert_eq!(
|
||||
settings.get("option1").unwrap().as_str().unwrap(),
|
||||
format!("value{}", i)
|
||||
);
|
||||
assert_eq!(settings.get("option2").unwrap().as_i64().unwrap() as i32, i);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user