diff --git a/Cargo.lock b/Cargo.lock index 13b491a1..fbb41a55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2414,6 +2414,7 @@ dependencies = [ "image 0.24.9", "include_dir", "indoc", + "keyring", "kill_tree", "lazy_static", "lopdf", diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index 602b071f..cd718134 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -43,7 +43,8 @@ lopdf = "0.35.0" docx-rs = "0.4.7" image = "0.24.9" umya-spreadsheet = "2.2.3" +keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] } [dev-dependencies] serial_test = "3.0.0" -sysinfo = "0.32.1" \ No newline at end of file +sysinfo = "0.32.1" diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 61028d7a..1bd545ff 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -1,9 +1,12 @@ +mod token_storage; + use indoc::indoc; use regex::Regex; use serde_json::{json, Value}; +use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; +use token_storage::{CredentialsManager, KeychainTokenStorage}; -use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin}; - +use mcp_core::content::Content; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, prompt::Prompt, @@ -14,8 +17,6 @@ use mcp_core::{ use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; -use mcp_core::content::Content; - use google_drive3::{ self, api::{File, Scope}, @@ -28,9 +29,7 @@ use google_drive3::{ }, DriveHub, }; - use google_sheets4::{self, Sheets}; - use http_body_util::BodyExt; /// async function to be pinned by the `present_user_url` method of the trait @@ -70,14 +69,15 @@ pub struct GoogleDriveRouter { instructions: String, drive: DriveHub>, sheets: Sheets>, + credentials_manager: Arc, } impl GoogleDriveRouter { async fn google_auth() -> ( DriveHub>, Sheets>, + Arc, ) { - let oauth_config = env::var("GOOGLE_DRIVE_OAUTH_CONFIG"); let keyfile_path_str = env::var("GOOGLE_DRIVE_OAUTH_PATH") .unwrap_or_else(|_| "./gcp-oauth.keys.json".to_string()); let credentials_path_str = env::var("GOOGLE_DRIVE_CREDENTIALS_PATH") @@ -87,7 +87,7 @@ impl GoogleDriveRouter { let keyfile_path = Path::new(expanded_keyfile.as_ref()); let expanded_credentials = shellexpand::tilde(credentials_path_str.as_str()); - let credentials_path = Path::new(expanded_credentials.as_ref()); + let credentials_path = expanded_credentials.to_string(); tracing::info!( credentials_path = credentials_path_str, @@ -95,35 +95,72 @@ impl GoogleDriveRouter { "Google Drive MCP server authentication config paths" ); - if !keyfile_path.exists() && oauth_config.is_ok() { - // attempt to create the path - if let Some(parent_dir) = keyfile_path.parent() { - let _ = fs::create_dir_all(parent_dir); + if let Ok(oauth_config) = env::var("GOOGLE_DRIVE_OAUTH_CONFIG") { + // Ensure the parent directory exists (create_dir_all is idempotent) + if let Some(parent) = keyfile_path.parent() { + if let Err(e) = fs::create_dir_all(parent) { + tracing::error!( + "Failed to create parent directories for {}: {}", + keyfile_path.display(), + e + ); + } } - if let Ok(mut file) = fs::File::create(keyfile_path) { - let _ = file.write_all(oauth_config.unwrap().as_bytes()); - tracing::debug!( - "Wrote Google Drive MCP server OAuth config to {}", - keyfile_path.display() - ); + // Check if the file exists and whether its content matches + // in every other case we attempt to overwrite + let need_to_write = match fs::read_to_string(keyfile_path) { + Ok(existing) if existing == oauth_config => false, + Ok(_) | Err(_) => true, + }; + + // Overwrite the file if needed + if need_to_write { + if let Err(e) = fs::write(keyfile_path, &oauth_config) { + tracing::error!( + "Failed to write OAuth config to {}: {}", + keyfile_path.display(), + e + ); + } else { + tracing::debug!( + "Wrote Google Drive MCP server OAuth config to {}", + keyfile_path.display() + ); + } } } + // Create a credentials manager for storing tokens securely + let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone())); + + // Read the application secret from the OAuth keyfile let secret = yup_oauth2::read_application_secret(keyfile_path) .await .expect("expected keyfile for google auth"); + // Create custom token storage using our credentials manager + let token_storage = KeychainTokenStorage::new( + secret + .project_id + .clone() + .unwrap_or("unknown-project-id".to_string()) + .to_string(), + credentials_manager.clone(), + ); + + // Create the authenticator with the installed flow let auth = InstalledFlowAuthenticator::builder( secret, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, ) - .persist_tokens_to_disk(credentials_path) + .with_storage(Box::new(token_storage)) // Use our custom storage .flow_delegate(Box::new(LocalhostBrowserDelegate)) .build() .await .expect("expected successful authentication"); + // Create the HTTP client let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) .build( @@ -138,11 +175,12 @@ impl GoogleDriveRouter { let drive_hub = DriveHub::new(client.clone(), auth.clone()); let sheets_hub = Sheets::new(client, auth); - (drive_hub, sheets_hub) + // Create and return the DriveHub + (drive_hub, sheets_hub, credentials_manager) } pub async fn new() -> Self { - let (drive, sheets) = Self::google_auth().await; + let (drive, sheets, credentials_manager) = Self::google_auth().await; // handle auth let search_tool = Tool::new( @@ -302,6 +340,7 @@ impl GoogleDriveRouter { instructions, drive, sheets, + credentials_manager, } } @@ -851,6 +890,7 @@ impl Clone for GoogleDriveRouter { instructions: self.instructions.clone(), drive: self.drive.clone(), sheets: self.sheets.clone(), + credentials_manager: self.credentials_manager.clone(), } } } diff --git a/crates/goose-mcp/src/google_drive/token_storage.rs b/crates/goose-mcp/src/google_drive/token_storage.rs new file mode 100644 index 00000000..f40ab4ba --- /dev/null +++ b/crates/goose-mcp/src/google_drive/token_storage.rs @@ -0,0 +1,301 @@ +use anyhow::Result; +use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage}; +use keyring::Entry; +use std::env; +use std::fs; +use std::path::Path; +use std::sync::Arc; +use thiserror::Error; +use tracing::{debug, error, warn}; + +const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; +const KEYCHAIN_USERNAME: &str = "oauth_credentials"; +const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK"; + +#[allow(dead_code)] +#[derive(Error, Debug)] +pub enum AuthError { + #[error("Failed to access keychain: {0}")] + KeyringError(#[from] keyring::Error), + #[error("Failed to access file system: {0}")] + FileSystemError(#[from] std::io::Error), + #[error("No credentials found")] + NotFound, + #[error("Critical error: {0}")] + Critical(String), + #[error("Failed to serialize/deserialize: {0}")] + SerializationError(#[from] serde_json::Error), +} + +/// CredentialsManager handles secure storage of OAuth credentials. +/// It attempts to store credentials in the system keychain first, +/// with fallback to file system storage if keychain access fails and fallback is enabled. +pub struct CredentialsManager { + credentials_path: String, + fallback_to_disk: bool, +} + +impl CredentialsManager { + pub fn new(credentials_path: String) -> Self { + // Check if we should fall back to disk, must be explicitly enabled + let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) { + Ok(value) => value.to_lowercase() == "true", + Err(_) => false, + }; + + Self { + credentials_path, + fallback_to_disk, + } + } + + pub fn read_credentials(&self) -> Result { + Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) + .and_then(|entry| entry.get_password()) + .inspect(|_| { + debug!("Successfully read credentials from keychain"); + }) + .or_else(|e| { + if self.fallback_to_disk { + debug!("Falling back to file system due to keyring error: {}", e); + self.read_from_file() + } else { + match e { + keyring::Error::NoEntry => Err(AuthError::NotFound), + _ => Err(AuthError::KeyringError(e)), + } + } + }) + } + + fn read_from_file(&self) -> Result { + let path = Path::new(&self.credentials_path); + if path.exists() { + match fs::read_to_string(path) { + Ok(content) => { + debug!("Successfully read credentials from file system"); + Ok(content) + } + Err(e) => { + error!("Failed to read credentials file: {}", e); + Err(AuthError::FileSystemError(e)) + } + } + } else { + debug!("No credentials found in file system"); + Err(AuthError::NotFound) + } + } + + pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> { + Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) + .and_then(|entry| entry.set_password(content)) + .inspect(|_| { + debug!("Successfully wrote credentials to keychain"); + }) + .or_else(|e| { + if self.fallback_to_disk { + warn!("Falling back to file system due to keyring error: {}", e); + self.write_to_file(content) + } else { + Err(AuthError::KeyringError(e)) + } + }) + } + + fn write_to_file(&self, content: &str) -> Result<(), AuthError> { + let path = Path::new(&self.credentials_path); + if let Some(parent) = path.parent() { + if !parent.exists() { + match fs::create_dir_all(parent) { + Ok(_) => debug!("Created parent directories for credentials file"), + Err(e) => { + error!("Failed to create directories for credentials file: {}", e); + return Err(AuthError::FileSystemError(e)); + } + } + } + } + + match fs::write(path, content) { + Ok(_) => { + debug!("Successfully wrote credentials to file system"); + Ok(()) + } + Err(e) => { + error!("Failed to write credentials to file system: {}", e); + Err(AuthError::FileSystemError(e)) + } + } + } +} + +/// Storage entry that includes the token, scopes and project it's valid for +#[derive(serde::Serialize, serde::Deserialize)] +struct StorageEntry { + token: TokenInfo, + scopes: String, + project_id: String, +} + +/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 +/// to enable secure storage of OAuth tokens in the system keychain. +pub struct KeychainTokenStorage { + project_id: String, + credentials_manager: Arc, +} + +impl KeychainTokenStorage { + /// Create a new KeychainTokenStorage with the given CredentialsManager + pub fn new(project_id: String, credentials_manager: Arc) -> Self { + Self { + project_id, + credentials_manager, + } + } + + fn generate_scoped_key(&self, scopes: &[&str]) -> String { + // Create a key based on the scopes and project_id + // Sort so we can be consistent using scopes as the key + let mut sorted_scopes = scopes.to_vec(); + sorted_scopes.sort(); + sorted_scopes.join(" ") + } +} + +#[async_trait::async_trait] +impl TokenStorage for KeychainTokenStorage { + /// Store a token in the keychain + async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { + let key = self.generate_scoped_key(scopes); + + // Create a storage entry that includes the scopes + let storage_entry = StorageEntry { + token: token_info, + scopes: key, + project_id: self.project_id.clone(), + }; + + let json = serde_json::to_string(&storage_entry)?; + self.credentials_manager + .write_credentials(&json) + .map_err(|e| { + error!("Failed to write token to keychain: {}", e); + anyhow::anyhow!("Failed to write token to keychain: {}", e) + }) + } + + /// Retrieve a token from the keychain + async fn get(&self, scopes: &[&str]) -> Option { + let key = self.generate_scoped_key(scopes); + + match self.credentials_manager.read_credentials() { + Ok(json) => { + debug!("Successfully read credentials from storage"); + match serde_json::from_str::(&json) { + Ok(entry) => { + // Check if token has the requested scopes and matches the project_id + if entry.project_id == self.project_id && entry.scopes == key { + debug!("Successfully retrieved OAuth token from storage"); + Some(entry.token) + } else { + None + } + } + Err(e) => { + warn!("Failed to deserialize token from storage: {}", e); + None + } + } + } + Err(AuthError::NotFound) => { + debug!("No OAuth token found in storage"); + None + } + Err(e) => { + warn!("Error reading OAuth token from storage: {}", e); + None + } + } + } +} + +impl Clone for CredentialsManager { + fn clone(&self) -> Self { + Self { + credentials_path: self.credentials_path.clone(), + fallback_to_disk: self.fallback_to_disk, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serial_test::serial; + use tempfile::NamedTempFile; + + #[tokio::test] + #[serial] + async fn test_token_storage_set_get() { + // Create a temporary file for testing + let temp_file = NamedTempFile::new().unwrap(); + let project_id = "test_project_1".to_string(); + let credentials_manager = Arc::new(CredentialsManager::new( + temp_file.path().to_string_lossy().to_string(), + )); + + let storage = KeychainTokenStorage::new(project_id, credentials_manager); + + // Create a test token + let token_info = TokenInfo { + access_token: Some("test_access_token".to_string()), + refresh_token: Some("test_refresh_token".to_string()), + expires_at: None, + id_token: None, + }; + + let scopes = &["https://www.googleapis.com/auth/drive.readonly"]; + + // Store the token + storage.set(scopes, token_info.clone()).await.unwrap(); + + // Retrieve the token + let retrieved = storage.get(scopes).await.unwrap(); + + // Verify the token matches + assert_eq!(retrieved.access_token, token_info.access_token); + assert_eq!(retrieved.refresh_token, token_info.refresh_token); + } + + #[tokio::test] + #[serial] + async fn test_token_storage_scope_mismatch() { + // Create a temporary file for testing + let temp_file = NamedTempFile::new().unwrap(); + let project_id = "test_project_2".to_string(); + let credentials_manager = Arc::new(CredentialsManager::new( + temp_file.path().to_string_lossy().to_string(), + )); + + let storage = KeychainTokenStorage::new(project_id, credentials_manager); + + // Create a test token + let token_info = TokenInfo { + access_token: Some("test_access_token".to_string()), + refresh_token: Some("test_refresh_token".to_string()), + expires_at: None, + id_token: None, + }; + + let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"]; + let scopes2 = &["https://www.googleapis.com/auth/drive.file"]; + + // Store the token with scopes1 + storage.set(scopes1, token_info).await.unwrap(); + + // Try to retrieve with different scopes + let result = storage.get(scopes2).await; + assert!(result.is_none()); + } +}