feat(google_drive): move credentials into keychain, add optional fallback (#1603)

This commit is contained in:
Kalvin C
2025-03-12 09:28:45 -07:00
committed by GitHub
parent d6cb7c6d87
commit 22c8b32c78
4 changed files with 365 additions and 22 deletions

1
Cargo.lock generated
View File

@@ -2414,6 +2414,7 @@ dependencies = [
"image 0.24.9",
"include_dir",
"indoc",
"keyring",
"kill_tree",
"lazy_static",
"lopdf",

View File

@@ -43,7 +43,8 @@ lopdf = "0.35.0"
docx-rs = "0.4.7"
image = "0.24.9"
umya-spreadsheet = "2.2.3"
keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] }
[dev-dependencies]
serial_test = "3.0.0"
sysinfo = "0.32.1"
sysinfo = "0.32.1"

View File

@@ -1,9 +1,12 @@
mod token_storage;
use indoc::indoc;
use regex::Regex;
use serde_json::{json, Value};
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
use token_storage::{CredentialsManager, KeychainTokenStorage};
use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin};
use mcp_core::content::Content;
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
@@ -14,8 +17,6 @@ use mcp_core::{
use mcp_server::router::CapabilitiesBuilder;
use mcp_server::Router;
use mcp_core::content::Content;
use google_drive3::{
self,
api::{File, Scope},
@@ -28,9 +29,7 @@ use google_drive3::{
},
DriveHub,
};
use google_sheets4::{self, Sheets};
use http_body_util::BodyExt;
/// async function to be pinned by the `present_user_url` method of the trait
@@ -70,14 +69,15 @@ pub struct GoogleDriveRouter {
instructions: String,
drive: DriveHub<HttpsConnector<HttpConnector>>,
sheets: Sheets<HttpsConnector<HttpConnector>>,
credentials_manager: Arc<CredentialsManager>,
}
impl GoogleDriveRouter {
async fn google_auth() -> (
DriveHub<HttpsConnector<HttpConnector>>,
Sheets<HttpsConnector<HttpConnector>>,
Arc<CredentialsManager>,
) {
let oauth_config = env::var("GOOGLE_DRIVE_OAUTH_CONFIG");
let keyfile_path_str = env::var("GOOGLE_DRIVE_OAUTH_PATH")
.unwrap_or_else(|_| "./gcp-oauth.keys.json".to_string());
let credentials_path_str = env::var("GOOGLE_DRIVE_CREDENTIALS_PATH")
@@ -87,7 +87,7 @@ impl GoogleDriveRouter {
let keyfile_path = Path::new(expanded_keyfile.as_ref());
let expanded_credentials = shellexpand::tilde(credentials_path_str.as_str());
let credentials_path = Path::new(expanded_credentials.as_ref());
let credentials_path = expanded_credentials.to_string();
tracing::info!(
credentials_path = credentials_path_str,
@@ -95,35 +95,72 @@ impl GoogleDriveRouter {
"Google Drive MCP server authentication config paths"
);
if !keyfile_path.exists() && oauth_config.is_ok() {
// attempt to create the path
if let Some(parent_dir) = keyfile_path.parent() {
let _ = fs::create_dir_all(parent_dir);
if let Ok(oauth_config) = env::var("GOOGLE_DRIVE_OAUTH_CONFIG") {
// Ensure the parent directory exists (create_dir_all is idempotent)
if let Some(parent) = keyfile_path.parent() {
if let Err(e) = fs::create_dir_all(parent) {
tracing::error!(
"Failed to create parent directories for {}: {}",
keyfile_path.display(),
e
);
}
}
if let Ok(mut file) = fs::File::create(keyfile_path) {
let _ = file.write_all(oauth_config.unwrap().as_bytes());
tracing::debug!(
"Wrote Google Drive MCP server OAuth config to {}",
keyfile_path.display()
);
// Check if the file exists and whether its content matches
// in every other case we attempt to overwrite
let need_to_write = match fs::read_to_string(keyfile_path) {
Ok(existing) if existing == oauth_config => false,
Ok(_) | Err(_) => true,
};
// Overwrite the file if needed
if need_to_write {
if let Err(e) = fs::write(keyfile_path, &oauth_config) {
tracing::error!(
"Failed to write OAuth config to {}: {}",
keyfile_path.display(),
e
);
} else {
tracing::debug!(
"Wrote Google Drive MCP server OAuth config to {}",
keyfile_path.display()
);
}
}
}
// Create a credentials manager for storing tokens securely
let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone()));
// Read the application secret from the OAuth keyfile
let secret = yup_oauth2::read_application_secret(keyfile_path)
.await
.expect("expected keyfile for google auth");
// Create custom token storage using our credentials manager
let token_storage = KeychainTokenStorage::new(
secret
.project_id
.clone()
.unwrap_or("unknown-project-id".to_string())
.to_string(),
credentials_manager.clone(),
);
// Create the authenticator with the installed flow
let auth = InstalledFlowAuthenticator::builder(
secret,
yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect,
)
.persist_tokens_to_disk(credentials_path)
.with_storage(Box::new(token_storage)) // Use our custom storage
.flow_delegate(Box::new(LocalhostBrowserDelegate))
.build()
.await
.expect("expected successful authentication");
// Create the HTTP client
let client =
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
.build(
@@ -138,11 +175,12 @@ impl GoogleDriveRouter {
let drive_hub = DriveHub::new(client.clone(), auth.clone());
let sheets_hub = Sheets::new(client, auth);
(drive_hub, sheets_hub)
// Create and return the DriveHub
(drive_hub, sheets_hub, credentials_manager)
}
pub async fn new() -> Self {
let (drive, sheets) = Self::google_auth().await;
let (drive, sheets, credentials_manager) = Self::google_auth().await;
// handle auth
let search_tool = Tool::new(
@@ -302,6 +340,7 @@ impl GoogleDriveRouter {
instructions,
drive,
sheets,
credentials_manager,
}
}
@@ -851,6 +890,7 @@ impl Clone for GoogleDriveRouter {
instructions: self.instructions.clone(),
drive: self.drive.clone(),
sheets: self.sheets.clone(),
credentials_manager: self.credentials_manager.clone(),
}
}
}

View File

@@ -0,0 +1,301 @@
use anyhow::Result;
use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage};
use keyring::Entry;
use std::env;
use std::fs;
use std::path::Path;
use std::sync::Arc;
use thiserror::Error;
use tracing::{debug, error, warn};
const KEYCHAIN_SERVICE: &str = "mcp_google_drive";
const KEYCHAIN_USERNAME: &str = "oauth_credentials";
const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK";
#[allow(dead_code)]
#[derive(Error, Debug)]
pub enum AuthError {
#[error("Failed to access keychain: {0}")]
KeyringError(#[from] keyring::Error),
#[error("Failed to access file system: {0}")]
FileSystemError(#[from] std::io::Error),
#[error("No credentials found")]
NotFound,
#[error("Critical error: {0}")]
Critical(String),
#[error("Failed to serialize/deserialize: {0}")]
SerializationError(#[from] serde_json::Error),
}
/// CredentialsManager handles secure storage of OAuth credentials.
/// It attempts to store credentials in the system keychain first,
/// with fallback to file system storage if keychain access fails and fallback is enabled.
pub struct CredentialsManager {
credentials_path: String,
fallback_to_disk: bool,
}
impl CredentialsManager {
pub fn new(credentials_path: String) -> Self {
// Check if we should fall back to disk, must be explicitly enabled
let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) {
Ok(value) => value.to_lowercase() == "true",
Err(_) => false,
};
Self {
credentials_path,
fallback_to_disk,
}
}
pub fn read_credentials(&self) -> Result<String, AuthError> {
Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME)
.and_then(|entry| entry.get_password())
.inspect(|_| {
debug!("Successfully read credentials from keychain");
})
.or_else(|e| {
if self.fallback_to_disk {
debug!("Falling back to file system due to keyring error: {}", e);
self.read_from_file()
} else {
match e {
keyring::Error::NoEntry => Err(AuthError::NotFound),
_ => Err(AuthError::KeyringError(e)),
}
}
})
}
fn read_from_file(&self) -> Result<String, AuthError> {
let path = Path::new(&self.credentials_path);
if path.exists() {
match fs::read_to_string(path) {
Ok(content) => {
debug!("Successfully read credentials from file system");
Ok(content)
}
Err(e) => {
error!("Failed to read credentials file: {}", e);
Err(AuthError::FileSystemError(e))
}
}
} else {
debug!("No credentials found in file system");
Err(AuthError::NotFound)
}
}
pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> {
Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME)
.and_then(|entry| entry.set_password(content))
.inspect(|_| {
debug!("Successfully wrote credentials to keychain");
})
.or_else(|e| {
if self.fallback_to_disk {
warn!("Falling back to file system due to keyring error: {}", e);
self.write_to_file(content)
} else {
Err(AuthError::KeyringError(e))
}
})
}
fn write_to_file(&self, content: &str) -> Result<(), AuthError> {
let path = Path::new(&self.credentials_path);
if let Some(parent) = path.parent() {
if !parent.exists() {
match fs::create_dir_all(parent) {
Ok(_) => debug!("Created parent directories for credentials file"),
Err(e) => {
error!("Failed to create directories for credentials file: {}", e);
return Err(AuthError::FileSystemError(e));
}
}
}
}
match fs::write(path, content) {
Ok(_) => {
debug!("Successfully wrote credentials to file system");
Ok(())
}
Err(e) => {
error!("Failed to write credentials to file system: {}", e);
Err(AuthError::FileSystemError(e))
}
}
}
}
/// Storage entry that includes the token, scopes and project it's valid for
#[derive(serde::Serialize, serde::Deserialize)]
struct StorageEntry {
token: TokenInfo,
scopes: String,
project_id: String,
}
/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2
/// to enable secure storage of OAuth tokens in the system keychain.
pub struct KeychainTokenStorage {
project_id: String,
credentials_manager: Arc<CredentialsManager>,
}
impl KeychainTokenStorage {
/// Create a new KeychainTokenStorage with the given CredentialsManager
pub fn new(project_id: String, credentials_manager: Arc<CredentialsManager>) -> Self {
Self {
project_id,
credentials_manager,
}
}
fn generate_scoped_key(&self, scopes: &[&str]) -> String {
// Create a key based on the scopes and project_id
// Sort so we can be consistent using scopes as the key
let mut sorted_scopes = scopes.to_vec();
sorted_scopes.sort();
sorted_scopes.join(" ")
}
}
#[async_trait::async_trait]
impl TokenStorage for KeychainTokenStorage {
/// Store a token in the keychain
async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> {
let key = self.generate_scoped_key(scopes);
// Create a storage entry that includes the scopes
let storage_entry = StorageEntry {
token: token_info,
scopes: key,
project_id: self.project_id.clone(),
};
let json = serde_json::to_string(&storage_entry)?;
self.credentials_manager
.write_credentials(&json)
.map_err(|e| {
error!("Failed to write token to keychain: {}", e);
anyhow::anyhow!("Failed to write token to keychain: {}", e)
})
}
/// Retrieve a token from the keychain
async fn get(&self, scopes: &[&str]) -> Option<TokenInfo> {
let key = self.generate_scoped_key(scopes);
match self.credentials_manager.read_credentials() {
Ok(json) => {
debug!("Successfully read credentials from storage");
match serde_json::from_str::<StorageEntry>(&json) {
Ok(entry) => {
// Check if token has the requested scopes and matches the project_id
if entry.project_id == self.project_id && entry.scopes == key {
debug!("Successfully retrieved OAuth token from storage");
Some(entry.token)
} else {
None
}
}
Err(e) => {
warn!("Failed to deserialize token from storage: {}", e);
None
}
}
}
Err(AuthError::NotFound) => {
debug!("No OAuth token found in storage");
None
}
Err(e) => {
warn!("Error reading OAuth token from storage: {}", e);
None
}
}
}
}
impl Clone for CredentialsManager {
fn clone(&self) -> Self {
Self {
credentials_path: self.credentials_path.clone(),
fallback_to_disk: self.fallback_to_disk,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
use tempfile::NamedTempFile;
#[tokio::test]
#[serial]
async fn test_token_storage_set_get() {
// Create a temporary file for testing
let temp_file = NamedTempFile::new().unwrap();
let project_id = "test_project_1".to_string();
let credentials_manager = Arc::new(CredentialsManager::new(
temp_file.path().to_string_lossy().to_string(),
));
let storage = KeychainTokenStorage::new(project_id, credentials_manager);
// Create a test token
let token_info = TokenInfo {
access_token: Some("test_access_token".to_string()),
refresh_token: Some("test_refresh_token".to_string()),
expires_at: None,
id_token: None,
};
let scopes = &["https://www.googleapis.com/auth/drive.readonly"];
// Store the token
storage.set(scopes, token_info.clone()).await.unwrap();
// Retrieve the token
let retrieved = storage.get(scopes).await.unwrap();
// Verify the token matches
assert_eq!(retrieved.access_token, token_info.access_token);
assert_eq!(retrieved.refresh_token, token_info.refresh_token);
}
#[tokio::test]
#[serial]
async fn test_token_storage_scope_mismatch() {
// Create a temporary file for testing
let temp_file = NamedTempFile::new().unwrap();
let project_id = "test_project_2".to_string();
let credentials_manager = Arc::new(CredentialsManager::new(
temp_file.path().to_string_lossy().to_string(),
));
let storage = KeychainTokenStorage::new(project_id, credentials_manager);
// Create a test token
let token_info = TokenInfo {
access_token: Some("test_access_token".to_string()),
refresh_token: Some("test_refresh_token".to_string()),
expires_at: None,
id_token: None,
};
let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"];
let scopes2 = &["https://www.googleapis.com/auth/drive.file"];
// Store the token with scopes1
storage.set(scopes1, token_info).await.unwrap();
// Try to retrieve with different scopes
let result = storage.get(scopes2).await;
assert!(result.is_none());
}
}