mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-20 14:04:32 +01:00
feat(google_drive): move credentials into keychain, add optional fallback (#1603)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2414,6 +2414,7 @@ dependencies = [
|
||||
"image 0.24.9",
|
||||
"include_dir",
|
||||
"indoc",
|
||||
"keyring",
|
||||
"kill_tree",
|
||||
"lazy_static",
|
||||
"lopdf",
|
||||
|
||||
@@ -43,7 +43,8 @@ lopdf = "0.35.0"
|
||||
docx-rs = "0.4.7"
|
||||
image = "0.24.9"
|
||||
umya-spreadsheet = "2.2.3"
|
||||
keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] }
|
||||
|
||||
[dev-dependencies]
|
||||
serial_test = "3.0.0"
|
||||
sysinfo = "0.32.1"
|
||||
sysinfo = "0.32.1"
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
mod token_storage;
|
||||
|
||||
use indoc::indoc;
|
||||
use regex::Regex;
|
||||
use serde_json::{json, Value};
|
||||
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
|
||||
use token_storage::{CredentialsManager, KeychainTokenStorage};
|
||||
|
||||
use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin};
|
||||
|
||||
use mcp_core::content::Content;
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
@@ -14,8 +17,6 @@ use mcp_core::{
|
||||
use mcp_server::router::CapabilitiesBuilder;
|
||||
use mcp_server::Router;
|
||||
|
||||
use mcp_core::content::Content;
|
||||
|
||||
use google_drive3::{
|
||||
self,
|
||||
api::{File, Scope},
|
||||
@@ -28,9 +29,7 @@ use google_drive3::{
|
||||
},
|
||||
DriveHub,
|
||||
};
|
||||
|
||||
use google_sheets4::{self, Sheets};
|
||||
|
||||
use http_body_util::BodyExt;
|
||||
|
||||
/// async function to be pinned by the `present_user_url` method of the trait
|
||||
@@ -70,14 +69,15 @@ pub struct GoogleDriveRouter {
|
||||
instructions: String,
|
||||
drive: DriveHub<HttpsConnector<HttpConnector>>,
|
||||
sheets: Sheets<HttpsConnector<HttpConnector>>,
|
||||
credentials_manager: Arc<CredentialsManager>,
|
||||
}
|
||||
|
||||
impl GoogleDriveRouter {
|
||||
async fn google_auth() -> (
|
||||
DriveHub<HttpsConnector<HttpConnector>>,
|
||||
Sheets<HttpsConnector<HttpConnector>>,
|
||||
Arc<CredentialsManager>,
|
||||
) {
|
||||
let oauth_config = env::var("GOOGLE_DRIVE_OAUTH_CONFIG");
|
||||
let keyfile_path_str = env::var("GOOGLE_DRIVE_OAUTH_PATH")
|
||||
.unwrap_or_else(|_| "./gcp-oauth.keys.json".to_string());
|
||||
let credentials_path_str = env::var("GOOGLE_DRIVE_CREDENTIALS_PATH")
|
||||
@@ -87,7 +87,7 @@ impl GoogleDriveRouter {
|
||||
let keyfile_path = Path::new(expanded_keyfile.as_ref());
|
||||
|
||||
let expanded_credentials = shellexpand::tilde(credentials_path_str.as_str());
|
||||
let credentials_path = Path::new(expanded_credentials.as_ref());
|
||||
let credentials_path = expanded_credentials.to_string();
|
||||
|
||||
tracing::info!(
|
||||
credentials_path = credentials_path_str,
|
||||
@@ -95,35 +95,72 @@ impl GoogleDriveRouter {
|
||||
"Google Drive MCP server authentication config paths"
|
||||
);
|
||||
|
||||
if !keyfile_path.exists() && oauth_config.is_ok() {
|
||||
// attempt to create the path
|
||||
if let Some(parent_dir) = keyfile_path.parent() {
|
||||
let _ = fs::create_dir_all(parent_dir);
|
||||
if let Ok(oauth_config) = env::var("GOOGLE_DRIVE_OAUTH_CONFIG") {
|
||||
// Ensure the parent directory exists (create_dir_all is idempotent)
|
||||
if let Some(parent) = keyfile_path.parent() {
|
||||
if let Err(e) = fs::create_dir_all(parent) {
|
||||
tracing::error!(
|
||||
"Failed to create parent directories for {}: {}",
|
||||
keyfile_path.display(),
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(mut file) = fs::File::create(keyfile_path) {
|
||||
let _ = file.write_all(oauth_config.unwrap().as_bytes());
|
||||
tracing::debug!(
|
||||
"Wrote Google Drive MCP server OAuth config to {}",
|
||||
keyfile_path.display()
|
||||
);
|
||||
// Check if the file exists and whether its content matches
|
||||
// in every other case we attempt to overwrite
|
||||
let need_to_write = match fs::read_to_string(keyfile_path) {
|
||||
Ok(existing) if existing == oauth_config => false,
|
||||
Ok(_) | Err(_) => true,
|
||||
};
|
||||
|
||||
// Overwrite the file if needed
|
||||
if need_to_write {
|
||||
if let Err(e) = fs::write(keyfile_path, &oauth_config) {
|
||||
tracing::error!(
|
||||
"Failed to write OAuth config to {}: {}",
|
||||
keyfile_path.display(),
|
||||
e
|
||||
);
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"Wrote Google Drive MCP server OAuth config to {}",
|
||||
keyfile_path.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a credentials manager for storing tokens securely
|
||||
let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone()));
|
||||
|
||||
// Read the application secret from the OAuth keyfile
|
||||
let secret = yup_oauth2::read_application_secret(keyfile_path)
|
||||
.await
|
||||
.expect("expected keyfile for google auth");
|
||||
|
||||
// Create custom token storage using our credentials manager
|
||||
let token_storage = KeychainTokenStorage::new(
|
||||
secret
|
||||
.project_id
|
||||
.clone()
|
||||
.unwrap_or("unknown-project-id".to_string())
|
||||
.to_string(),
|
||||
credentials_manager.clone(),
|
||||
);
|
||||
|
||||
// Create the authenticator with the installed flow
|
||||
let auth = InstalledFlowAuthenticator::builder(
|
||||
secret,
|
||||
yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect,
|
||||
)
|
||||
.persist_tokens_to_disk(credentials_path)
|
||||
.with_storage(Box::new(token_storage)) // Use our custom storage
|
||||
.flow_delegate(Box::new(LocalhostBrowserDelegate))
|
||||
.build()
|
||||
.await
|
||||
.expect("expected successful authentication");
|
||||
|
||||
// Create the HTTP client
|
||||
let client =
|
||||
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
|
||||
.build(
|
||||
@@ -138,11 +175,12 @@ impl GoogleDriveRouter {
|
||||
let drive_hub = DriveHub::new(client.clone(), auth.clone());
|
||||
let sheets_hub = Sheets::new(client, auth);
|
||||
|
||||
(drive_hub, sheets_hub)
|
||||
// Create and return the DriveHub
|
||||
(drive_hub, sheets_hub, credentials_manager)
|
||||
}
|
||||
|
||||
pub async fn new() -> Self {
|
||||
let (drive, sheets) = Self::google_auth().await;
|
||||
let (drive, sheets, credentials_manager) = Self::google_auth().await;
|
||||
|
||||
// handle auth
|
||||
let search_tool = Tool::new(
|
||||
@@ -302,6 +340,7 @@ impl GoogleDriveRouter {
|
||||
instructions,
|
||||
drive,
|
||||
sheets,
|
||||
credentials_manager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -851,6 +890,7 @@ impl Clone for GoogleDriveRouter {
|
||||
instructions: self.instructions.clone(),
|
||||
drive: self.drive.clone(),
|
||||
sheets: self.sheets.clone(),
|
||||
credentials_manager: self.credentials_manager.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
301
crates/goose-mcp/src/google_drive/token_storage.rs
Normal file
301
crates/goose-mcp/src/google_drive/token_storage.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
use anyhow::Result;
|
||||
use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage};
|
||||
use keyring::Entry;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
const KEYCHAIN_SERVICE: &str = "mcp_google_drive";
|
||||
const KEYCHAIN_USERNAME: &str = "oauth_credentials";
|
||||
const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK";
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AuthError {
|
||||
#[error("Failed to access keychain: {0}")]
|
||||
KeyringError(#[from] keyring::Error),
|
||||
#[error("Failed to access file system: {0}")]
|
||||
FileSystemError(#[from] std::io::Error),
|
||||
#[error("No credentials found")]
|
||||
NotFound,
|
||||
#[error("Critical error: {0}")]
|
||||
Critical(String),
|
||||
#[error("Failed to serialize/deserialize: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
/// CredentialsManager handles secure storage of OAuth credentials.
|
||||
/// It attempts to store credentials in the system keychain first,
|
||||
/// with fallback to file system storage if keychain access fails and fallback is enabled.
|
||||
pub struct CredentialsManager {
|
||||
credentials_path: String,
|
||||
fallback_to_disk: bool,
|
||||
}
|
||||
|
||||
impl CredentialsManager {
|
||||
pub fn new(credentials_path: String) -> Self {
|
||||
// Check if we should fall back to disk, must be explicitly enabled
|
||||
let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) {
|
||||
Ok(value) => value.to_lowercase() == "true",
|
||||
Err(_) => false,
|
||||
};
|
||||
|
||||
Self {
|
||||
credentials_path,
|
||||
fallback_to_disk,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_credentials(&self) -> Result<String, AuthError> {
|
||||
Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME)
|
||||
.and_then(|entry| entry.get_password())
|
||||
.inspect(|_| {
|
||||
debug!("Successfully read credentials from keychain");
|
||||
})
|
||||
.or_else(|e| {
|
||||
if self.fallback_to_disk {
|
||||
debug!("Falling back to file system due to keyring error: {}", e);
|
||||
self.read_from_file()
|
||||
} else {
|
||||
match e {
|
||||
keyring::Error::NoEntry => Err(AuthError::NotFound),
|
||||
_ => Err(AuthError::KeyringError(e)),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn read_from_file(&self) -> Result<String, AuthError> {
|
||||
let path = Path::new(&self.credentials_path);
|
||||
if path.exists() {
|
||||
match fs::read_to_string(path) {
|
||||
Ok(content) => {
|
||||
debug!("Successfully read credentials from file system");
|
||||
Ok(content)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to read credentials file: {}", e);
|
||||
Err(AuthError::FileSystemError(e))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("No credentials found in file system");
|
||||
Err(AuthError::NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> {
|
||||
Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME)
|
||||
.and_then(|entry| entry.set_password(content))
|
||||
.inspect(|_| {
|
||||
debug!("Successfully wrote credentials to keychain");
|
||||
})
|
||||
.or_else(|e| {
|
||||
if self.fallback_to_disk {
|
||||
warn!("Falling back to file system due to keyring error: {}", e);
|
||||
self.write_to_file(content)
|
||||
} else {
|
||||
Err(AuthError::KeyringError(e))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn write_to_file(&self, content: &str) -> Result<(), AuthError> {
|
||||
let path = Path::new(&self.credentials_path);
|
||||
if let Some(parent) = path.parent() {
|
||||
if !parent.exists() {
|
||||
match fs::create_dir_all(parent) {
|
||||
Ok(_) => debug!("Created parent directories for credentials file"),
|
||||
Err(e) => {
|
||||
error!("Failed to create directories for credentials file: {}", e);
|
||||
return Err(AuthError::FileSystemError(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match fs::write(path, content) {
|
||||
Ok(_) => {
|
||||
debug!("Successfully wrote credentials to file system");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to write credentials to file system: {}", e);
|
||||
Err(AuthError::FileSystemError(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Storage entry that includes the token, scopes and project it's valid for
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct StorageEntry {
|
||||
token: TokenInfo,
|
||||
scopes: String,
|
||||
project_id: String,
|
||||
}
|
||||
|
||||
/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2
|
||||
/// to enable secure storage of OAuth tokens in the system keychain.
|
||||
pub struct KeychainTokenStorage {
|
||||
project_id: String,
|
||||
credentials_manager: Arc<CredentialsManager>,
|
||||
}
|
||||
|
||||
impl KeychainTokenStorage {
|
||||
/// Create a new KeychainTokenStorage with the given CredentialsManager
|
||||
pub fn new(project_id: String, credentials_manager: Arc<CredentialsManager>) -> Self {
|
||||
Self {
|
||||
project_id,
|
||||
credentials_manager,
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_scoped_key(&self, scopes: &[&str]) -> String {
|
||||
// Create a key based on the scopes and project_id
|
||||
// Sort so we can be consistent using scopes as the key
|
||||
let mut sorted_scopes = scopes.to_vec();
|
||||
sorted_scopes.sort();
|
||||
sorted_scopes.join(" ")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TokenStorage for KeychainTokenStorage {
|
||||
/// Store a token in the keychain
|
||||
async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> {
|
||||
let key = self.generate_scoped_key(scopes);
|
||||
|
||||
// Create a storage entry that includes the scopes
|
||||
let storage_entry = StorageEntry {
|
||||
token: token_info,
|
||||
scopes: key,
|
||||
project_id: self.project_id.clone(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&storage_entry)?;
|
||||
self.credentials_manager
|
||||
.write_credentials(&json)
|
||||
.map_err(|e| {
|
||||
error!("Failed to write token to keychain: {}", e);
|
||||
anyhow::anyhow!("Failed to write token to keychain: {}", e)
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieve a token from the keychain
|
||||
async fn get(&self, scopes: &[&str]) -> Option<TokenInfo> {
|
||||
let key = self.generate_scoped_key(scopes);
|
||||
|
||||
match self.credentials_manager.read_credentials() {
|
||||
Ok(json) => {
|
||||
debug!("Successfully read credentials from storage");
|
||||
match serde_json::from_str::<StorageEntry>(&json) {
|
||||
Ok(entry) => {
|
||||
// Check if token has the requested scopes and matches the project_id
|
||||
if entry.project_id == self.project_id && entry.scopes == key {
|
||||
debug!("Successfully retrieved OAuth token from storage");
|
||||
Some(entry.token)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to deserialize token from storage: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(AuthError::NotFound) => {
|
||||
debug!("No OAuth token found in storage");
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error reading OAuth token from storage: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for CredentialsManager {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
credentials_path: self.credentials_path.clone(),
|
||||
fallback_to_disk: self.fallback_to_disk,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serial_test::serial;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_token_storage_set_get() {
|
||||
// Create a temporary file for testing
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let project_id = "test_project_1".to_string();
|
||||
let credentials_manager = Arc::new(CredentialsManager::new(
|
||||
temp_file.path().to_string_lossy().to_string(),
|
||||
));
|
||||
|
||||
let storage = KeychainTokenStorage::new(project_id, credentials_manager);
|
||||
|
||||
// Create a test token
|
||||
let token_info = TokenInfo {
|
||||
access_token: Some("test_access_token".to_string()),
|
||||
refresh_token: Some("test_refresh_token".to_string()),
|
||||
expires_at: None,
|
||||
id_token: None,
|
||||
};
|
||||
|
||||
let scopes = &["https://www.googleapis.com/auth/drive.readonly"];
|
||||
|
||||
// Store the token
|
||||
storage.set(scopes, token_info.clone()).await.unwrap();
|
||||
|
||||
// Retrieve the token
|
||||
let retrieved = storage.get(scopes).await.unwrap();
|
||||
|
||||
// Verify the token matches
|
||||
assert_eq!(retrieved.access_token, token_info.access_token);
|
||||
assert_eq!(retrieved.refresh_token, token_info.refresh_token);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_token_storage_scope_mismatch() {
|
||||
// Create a temporary file for testing
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let project_id = "test_project_2".to_string();
|
||||
let credentials_manager = Arc::new(CredentialsManager::new(
|
||||
temp_file.path().to_string_lossy().to_string(),
|
||||
));
|
||||
|
||||
let storage = KeychainTokenStorage::new(project_id, credentials_manager);
|
||||
|
||||
// Create a test token
|
||||
let token_info = TokenInfo {
|
||||
access_token: Some("test_access_token".to_string()),
|
||||
refresh_token: Some("test_refresh_token".to_string()),
|
||||
expires_at: None,
|
||||
id_token: None,
|
||||
};
|
||||
|
||||
let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"];
|
||||
let scopes2 = &["https://www.googleapis.com/auth/drive.file"];
|
||||
|
||||
// Store the token with scopes1
|
||||
storage.set(scopes1, token_info).await.unwrap();
|
||||
|
||||
// Try to retrieve with different scopes
|
||||
let result = storage.get(scopes2).await;
|
||||
assert!(result.is_none());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user