fix: extension config write locks (#2283)

This commit is contained in:
Alex Hancock
2025-04-21 14:36:12 -04:00
committed by GitHub
parent 4b0cd95310
commit a1b46697f3
5 changed files with 202 additions and 2 deletions

View File

@@ -72,6 +72,7 @@ jsonwebtoken = "9.3.1"
# Added blake3 hashing library as a dependency
blake3 = "1.5"
fs2 = "0.4.3"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -1,10 +1,13 @@
use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs};
use fs2::FileExt;
use keyring::Entry;
use once_cell::sync::{Lazy, OnceCell};
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;
use std::env;
use std::fs::OpenOptions;
use std::io::Write;
use std::path::{Path, PathBuf};
use thiserror::Error;
@@ -32,6 +35,8 @@ pub enum ConfigError {
DirectoryError(String),
#[error("Failed to access keyring: {0}")]
KeyringError(String),
#[error("Failed to lock config file: {0}")]
LockError(String),
}
impl From<serde_json::Error> for ConfigError {
@@ -220,7 +225,22 @@ impl Config {
.map_err(|e| ConfigError::DirectoryError(e.to_string()))?;
}
std::fs::write(&self.config_path, yaml_value)?;
// Open the file with write permissions, create if it doesn't exist
let mut file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&self.config_path)?;
// Acquire an exclusive lock
file.lock_exclusive()
.map_err(|e| ConfigError::LockError(e.to_string()))?;
// Write the contents using the same file handle
file.write_all(yaml_value.as_bytes())?;
file.sync_all()?;
// Unlock is handled automatically when file is dropped
Ok(())
}
@@ -621,4 +641,171 @@ mod tests {
cleanup_keyring()?;
Ok(())
}
#[test]
fn test_concurrent_writes() -> Result<(), ConfigError> {
use std::sync::{Arc, Barrier, Mutex};
use std::thread;
let temp_file = NamedTempFile::new().unwrap();
let config = Arc::new(Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?);
let barrier = Arc::new(Barrier::new(3)); // For 3 concurrent threads
let values = Arc::new(Mutex::new(HashMap::new()));
let mut handles = vec![];
// Initialize with empty values
config.save_values(HashMap::new())?;
// Spawn 3 threads that will try to write simultaneously
for i in 0..3 {
let config = Arc::clone(&config);
let barrier = Arc::clone(&barrier);
let values = Arc::clone(&values);
let handle = thread::spawn(move || -> Result<(), ConfigError> {
// Wait for all threads to reach this point
barrier.wait();
// Get the lock and update values
let mut values = values.lock().unwrap();
values.insert(format!("key{}", i), Value::String(format!("value{}", i)));
// Write all values
config.save_values(values.clone())?;
Ok(())
});
handles.push(handle);
}
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap()?;
}
// Verify all values were written correctly
let final_values = config.load_values()?;
// Print the final values for debugging
println!("Final values: {:?}", final_values);
assert_eq!(
final_values.len(),
3,
"Expected 3 values, got {}",
final_values.len()
);
for i in 0..3 {
let key = format!("key{}", i);
let value = format!("value{}", i);
assert!(
final_values.get(&key).is_some(),
"Missing key {} in final values",
key
);
assert_eq!(
final_values.get(&key).unwrap(),
&Value::String(value),
"Incorrect value for key {}",
key
);
}
Ok(())
}
#[test]
fn test_concurrent_extension_writes() -> Result<(), ConfigError> {
use std::sync::{Arc, Barrier};
use std::thread;
use std::time::Duration;
let temp_file = NamedTempFile::new().unwrap();
let config = Arc::new(Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?);
let barrier = Arc::new(Barrier::new(3)); // For 3 concurrent threads
let mut handles = vec![];
// Initialize with empty values
config.save_values(HashMap::new())?;
// Spawn 3 threads that will try to write extension configs simultaneously
for i in 0..3 {
let config = Arc::clone(&config);
let barrier = Arc::clone(&barrier);
let handle = thread::spawn(move || -> Result<(), ConfigError> {
// Wait for all threads to reach this point
barrier.wait();
// Add a small random delay to increase chance of concurrent access
thread::sleep(Duration::from_millis(i * 10));
let extension_key = format!("extension_{}", i);
let mut values = config.load_values()?;
values.insert(
extension_key.clone(),
serde_json::json!({
"name": format!("test_extension_{}", i),
"version": format!("1.0.{}", i),
"enabled": true,
"settings": {
"option1": format!("value{}", i),
"option2": i
}
}),
);
// Write all values atomically
config.save_values(values)?;
Ok(())
});
handles.push(handle);
}
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap()?;
}
// Verify all extension configs were written correctly
let final_values = config.load_values()?;
// Print the final values for debugging
println!("Final values: {:?}", final_values);
assert_eq!(
final_values.len(),
3,
"Expected 3 extension configs, got {}",
final_values.len()
);
for i in 0..3 {
let extension_key = format!("extension_{}", i);
let config = final_values
.get(&extension_key)
.expect(&format!("Missing extension config for {}", extension_key));
// Verify the structure matches what we wrote
let config_obj = config.as_object().unwrap();
assert_eq!(
config_obj.get("name").unwrap().as_str().unwrap(),
format!("test_extension_{}", i)
);
assert_eq!(
config_obj.get("version").unwrap().as_str().unwrap(),
format!("1.0.{}", i)
);
assert!(config_obj.get("enabled").unwrap().as_bool().unwrap());
let settings = config_obj.get("settings").unwrap().as_object().unwrap();
assert_eq!(
settings.get("option1").unwrap().as_str().unwrap(),
format!("value{}", i)
);
assert_eq!(settings.get("option2").unwrap().as_i64().unwrap() as i32, i);
}
Ok(())
}
}