mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
Add native OAuth 2.0 authentication support to MCP client (#3213)
Co-authored-by: Alex Hancock <alexhancock@block.xyz>
This commit is contained in:
committed by
GitHub
parent
825c2258ef
commit
883bc67d3b
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -5376,14 +5376,20 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"axum",
|
||||||
|
"base64 0.22.1",
|
||||||
|
"chrono",
|
||||||
"eventsource-client",
|
"eventsource-client",
|
||||||
"futures",
|
"futures",
|
||||||
"mcp-core",
|
"mcp-core",
|
||||||
|
"nanoid",
|
||||||
"nix 0.30.1",
|
"nix 0.30.1",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"reqwest 0.11.27",
|
"reqwest 0.11.27",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"serde_urlencoded",
|
||||||
|
"sha2",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
@@ -5392,6 +5398,7 @@ dependencies = [
|
|||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"url",
|
"url",
|
||||||
|
"webbrowser 1.0.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
@@ -25,5 +25,13 @@ tower = { version = "0.4", features = ["timeout", "util"] }
|
|||||||
tower-service = "0.3"
|
tower-service = "0.3"
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
nix = { version = "0.30.1", features = ["process", "signal"] }
|
nix = { version = "0.30.1", features = ["process", "signal"] }
|
||||||
|
# OAuth dependencies
|
||||||
|
axum = { version = "0.8", features = ["query"] }
|
||||||
|
base64 = "0.22"
|
||||||
|
sha2 = "0.10"
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
nanoid = "0.4"
|
||||||
|
webbrowser = "1.0"
|
||||||
|
serde_urlencoded = "0.7"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|||||||
64
crates/mcp-client/examples/test_auth.rs
Normal file
64
crates/mcp-client/examples/test_auth.rs
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||||
|
use mcp_client::transport::{StreamableHttpTransport, Transport};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
// Initialize logging
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(
|
||||||
|
EnvFilter::from_default_env()
|
||||||
|
.add_directive("mcp_client=debug".parse().unwrap())
|
||||||
|
.add_directive("eventsource_client=info".parse().unwrap()),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
println!("Testing Streamable HTTP transport with OAuth 2.0 authentication...");
|
||||||
|
|
||||||
|
// Create the Streamable HTTP transport for any MCP service that supports OAuth
|
||||||
|
// This example uses a hypothetical MCP endpoint - replace with actual service
|
||||||
|
let mcp_endpoint =
|
||||||
|
std::env::var("MCP_ENDPOINT").unwrap_or_else(|_| "https://example.com/mcp".to_string());
|
||||||
|
|
||||||
|
println!("Using MCP endpoint: {}", mcp_endpoint);
|
||||||
|
|
||||||
|
let transport = StreamableHttpTransport::new(&mcp_endpoint, HashMap::new());
|
||||||
|
|
||||||
|
// Start transport
|
||||||
|
let handle = transport.start().await?;
|
||||||
|
|
||||||
|
// Create client
|
||||||
|
let mut client = McpClient::connect(handle, Duration::from_secs(30)).await?;
|
||||||
|
println!("Client created with Streamable HTTP transport\n");
|
||||||
|
|
||||||
|
// Initialize - this will trigger the OAuth flow if authentication is needed
|
||||||
|
// The implementation now includes:
|
||||||
|
// - RFC 8707 Resource Parameter support for proper token audience binding
|
||||||
|
// - Proper OAuth 2.0 discovery with multiple fallback paths
|
||||||
|
// - Dynamic client registration (RFC 7591)
|
||||||
|
// - PKCE for security (RFC 7636)
|
||||||
|
// - MCP-Protocol-Version header as required by the specification
|
||||||
|
let server_info = client
|
||||||
|
.initialize(
|
||||||
|
ClientInfo {
|
||||||
|
name: "streamable-http-auth-test".into(),
|
||||||
|
version: "1.0.0".into(),
|
||||||
|
},
|
||||||
|
ClientCapabilities::default(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
println!("Connected to server: {server_info:?}\n");
|
||||||
|
println!("OAuth 2.0 authentication test completed successfully!");
|
||||||
|
println!("\nKey improvements implemented:");
|
||||||
|
println!("✓ RFC 8707 Resource Parameter implementation");
|
||||||
|
println!("✓ MCP-Protocol-Version header support");
|
||||||
|
println!("✓ Enhanced OAuth discovery with multiple fallback paths");
|
||||||
|
println!("✓ Proper canonical resource URI generation");
|
||||||
|
println!("✓ Full compliance with MCP Authorization specification");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1,8 +1,13 @@
|
|||||||
pub mod client;
|
pub mod client;
|
||||||
|
pub mod oauth;
|
||||||
pub mod service;
|
pub mod service;
|
||||||
pub mod transport;
|
pub mod transport;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod oauth_tests;
|
||||||
|
|
||||||
pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait};
|
pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait};
|
||||||
|
pub use oauth::{authenticate_service, ServiceConfig};
|
||||||
pub use service::McpService;
|
pub use service::McpService;
|
||||||
pub use transport::{
|
pub use transport::{
|
||||||
SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle,
|
SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle,
|
||||||
|
|||||||
456
crates/mcp-client/src/oauth.rs
Normal file
456
crates/mcp-client/src/oauth.rs
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use axum::{extract::Query, response::Html, routing::get, Router};
|
||||||
|
use base64::Engine;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use sha2::Digest;
|
||||||
|
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||||
|
use tokio::sync::{oneshot, Mutex as TokioMutex};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct OidcEndpoints {
|
||||||
|
authorization_endpoint: String,
|
||||||
|
token_endpoint: String,
|
||||||
|
registration_endpoint: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct TokenData {
|
||||||
|
access_token: String,
|
||||||
|
refresh_token: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct ClientRegistrationRequest {
|
||||||
|
redirect_uris: Vec<String>,
|
||||||
|
token_endpoint_auth_method: String,
|
||||||
|
grant_types: Vec<String>,
|
||||||
|
response_types: Vec<String>,
|
||||||
|
client_name: String,
|
||||||
|
client_uri: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct ClientRegistrationResponse {
|
||||||
|
client_id: String,
|
||||||
|
client_id_issued_at: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
client_secret: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// OAuth configuration for any service
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ServiceConfig {
|
||||||
|
pub oauth_host: String,
|
||||||
|
pub redirect_uri: String,
|
||||||
|
pub client_name: String,
|
||||||
|
pub client_uri: String,
|
||||||
|
pub discovery_path: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServiceConfig {
|
||||||
|
/// Create a generic OAuth configuration from an MCP endpoint URL
|
||||||
|
/// Extracts the base URL for OAuth discovery
|
||||||
|
pub fn from_mcp_endpoint(mcp_url: &str) -> Result<Self> {
|
||||||
|
let parsed_url = Url::parse(mcp_url.trim())?;
|
||||||
|
let oauth_host = format!(
|
||||||
|
"{}://{}{}",
|
||||||
|
parsed_url.scheme(),
|
||||||
|
parsed_url.host_str().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url)
|
||||||
|
})?,
|
||||||
|
if let Some(port) = parsed_url.port() {
|
||||||
|
format!(":{}", port)
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
oauth_host,
|
||||||
|
redirect_uri: "http://localhost:8020".to_string(),
|
||||||
|
client_name: "Goose MCP Client".to_string(),
|
||||||
|
client_uri: "https://github.com/block/goose".to_string(),
|
||||||
|
discovery_path: None, // Use standard discovery
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create configuration with custom discovery path for non-standard services
|
||||||
|
pub fn with_custom_discovery(mut self, discovery_path: String) -> Self {
|
||||||
|
self.discovery_path = Some(discovery_path);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the canonical resource URI for the MCP server
|
||||||
|
/// This is used as the resource parameter in OAuth requests (RFC 8707)
|
||||||
|
pub fn get_canonical_resource_uri(&self, mcp_url: &str) -> Result<String> {
|
||||||
|
let parsed_url = Url::parse(mcp_url.trim())?;
|
||||||
|
|
||||||
|
// Build canonical URI: scheme://host[:port][/path]
|
||||||
|
let mut canonical = format!(
|
||||||
|
"{}://{}",
|
||||||
|
parsed_url.scheme().to_lowercase(),
|
||||||
|
parsed_url
|
||||||
|
.host_str()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url)
|
||||||
|
})?
|
||||||
|
.to_lowercase()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add port if not default
|
||||||
|
if let Some(port) = parsed_url.port() {
|
||||||
|
canonical.push_str(&format!(":{}", port));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add path if present and not just "/"
|
||||||
|
let path = parsed_url.path();
|
||||||
|
if !path.is_empty() && path != "/" {
|
||||||
|
canonical.push_str(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(canonical)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct OAuthFlow {
|
||||||
|
endpoints: OidcEndpoints,
|
||||||
|
client_id: String,
|
||||||
|
redirect_url: String,
|
||||||
|
state: String,
|
||||||
|
verifier: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthFlow {
|
||||||
|
fn new(endpoints: OidcEndpoints, client_id: String, redirect_url: String) -> Self {
|
||||||
|
Self {
|
||||||
|
endpoints,
|
||||||
|
client_id,
|
||||||
|
redirect_url,
|
||||||
|
state: nanoid::nanoid!(16),
|
||||||
|
verifier: nanoid::nanoid!(64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a dynamic client and return the client_id
|
||||||
|
async fn register_client(endpoints: &OidcEndpoints, config: &ServiceConfig) -> Result<String> {
|
||||||
|
let Some(registration_endpoint) = &endpoints.registration_endpoint else {
|
||||||
|
return Err(anyhow::anyhow!("No registration endpoint available"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let registration_request = ClientRegistrationRequest {
|
||||||
|
redirect_uris: vec![config.redirect_uri.clone()],
|
||||||
|
token_endpoint_auth_method: "none".to_string(),
|
||||||
|
grant_types: vec![
|
||||||
|
"authorization_code".to_string(),
|
||||||
|
"refresh_token".to_string(),
|
||||||
|
],
|
||||||
|
response_types: vec!["code".to_string()],
|
||||||
|
client_name: config.client_name.clone(),
|
||||||
|
client_uri: config.client_uri.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!("Registering dynamic client with OAuth server...");
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.post(registration_endpoint)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.json(®istration_request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let status = resp.status();
|
||||||
|
let err_text = resp.text().await?;
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Failed to register client: {} - {}",
|
||||||
|
status,
|
||||||
|
err_text
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let registration_response: ClientRegistrationResponse = resp.json().await?;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
"Client registered successfully with ID: {}",
|
||||||
|
registration_response.client_id
|
||||||
|
);
|
||||||
|
Ok(registration_response.client_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_authorization_url(&self, resource: &str) -> String {
|
||||||
|
let challenge = {
|
||||||
|
let digest = sha2::Sha256::digest(self.verifier.as_bytes());
|
||||||
|
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
|
||||||
|
};
|
||||||
|
|
||||||
|
let params = [
|
||||||
|
("response_type", "code"),
|
||||||
|
("client_id", &self.client_id),
|
||||||
|
("redirect_uri", &self.redirect_url),
|
||||||
|
("state", &self.state),
|
||||||
|
("code_challenge", &challenge),
|
||||||
|
("code_challenge_method", "S256"),
|
||||||
|
("resource", resource), // RFC 8707 Resource Parameter
|
||||||
|
];
|
||||||
|
|
||||||
|
format!(
|
||||||
|
"{}?{}",
|
||||||
|
self.endpoints.authorization_endpoint,
|
||||||
|
serde_urlencoded::to_string(params).unwrap()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn exchange_code_for_token(&self, code: &str, resource: &str) -> Result<TokenData> {
|
||||||
|
let params = [
|
||||||
|
("grant_type", "authorization_code"),
|
||||||
|
("code", code),
|
||||||
|
("redirect_uri", &self.redirect_url),
|
||||||
|
("code_verifier", &self.verifier),
|
||||||
|
("client_id", &self.client_id),
|
||||||
|
("resource", resource), // RFC 8707 Resource Parameter
|
||||||
|
];
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.post(&self.endpoints.token_endpoint)
|
||||||
|
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
.form(¶ms)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let err_text = resp.text().await?;
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Failed to exchange code for token: {}",
|
||||||
|
err_text
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_response: Value = resp.json().await?;
|
||||||
|
|
||||||
|
let access_token = token_response
|
||||||
|
.get("access_token")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let refresh_token = token_response
|
||||||
|
.get("refresh_token")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
Ok(TokenData {
|
||||||
|
access_token,
|
||||||
|
refresh_token,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, resource: &str) -> Result<TokenData> {
|
||||||
|
// Create a channel that will send the auth code from the callback
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
let state = self.state.clone();
|
||||||
|
let tx = Arc::new(TokioMutex::new(Some(tx)));
|
||||||
|
|
||||||
|
// Setup a server that will receive the redirect and capture the code
|
||||||
|
let app = Router::new().route(
|
||||||
|
"/",
|
||||||
|
get(move |Query(params): Query<HashMap<String, String>>| {
|
||||||
|
let tx = Arc::clone(&tx);
|
||||||
|
let state = state.clone();
|
||||||
|
async move {
|
||||||
|
let code = params.get("code").cloned();
|
||||||
|
let received_state = params.get("state").cloned();
|
||||||
|
|
||||||
|
if let (Some(code), Some(received_state)) = (code, received_state) {
|
||||||
|
if received_state == state {
|
||||||
|
if let Some(sender) = tx.lock().await.take() {
|
||||||
|
if sender.send(code).is_ok() {
|
||||||
|
return Html(
|
||||||
|
"<h2>Authentication Successful!</h2><p>You can close this window and return to the application.</p>",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Html("<h2>Error</h2><p>Authentication already completed.</p>")
|
||||||
|
} else {
|
||||||
|
Html("<h2>Error</h2><p>State mismatch - possible security issue.</p>")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Html("<h2>Error</h2><p>Authentication failed - missing parameters.</p>")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Start the callback server
|
||||||
|
let redirect_url = Url::parse(&self.redirect_url)?;
|
||||||
|
let port = redirect_url.port().unwrap_or(8020);
|
||||||
|
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
||||||
|
|
||||||
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
let server = axum::serve(listener, app);
|
||||||
|
server.await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Open the browser for OAuth
|
||||||
|
let authorization_url = self.get_authorization_url(resource);
|
||||||
|
tracing::info!("Opening browser for OAuth authentication...");
|
||||||
|
|
||||||
|
if webbrowser::open(&authorization_url).is_err() {
|
||||||
|
tracing::warn!("Could not open browser automatically. Please open this URL manually:");
|
||||||
|
tracing::warn!("{}", authorization_url);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the authorization code with a timeout
|
||||||
|
let code = tokio::time::timeout(
|
||||||
|
std::time::Duration::from_secs(120), // 2 minute timeout
|
||||||
|
rx,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??;
|
||||||
|
|
||||||
|
// Stop the callback server
|
||||||
|
server_handle.abort();
|
||||||
|
|
||||||
|
// Exchange the code for a token
|
||||||
|
self.exchange_code_for_token(&code, resource).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_oauth_endpoints(
|
||||||
|
host: &str,
|
||||||
|
custom_discovery_path: Option<&str>,
|
||||||
|
) -> Result<OidcEndpoints> {
|
||||||
|
let base_url = Url::parse(host)?;
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
// Define discovery paths to try, with custom path first if provided
|
||||||
|
let mut discovery_paths = Vec::new();
|
||||||
|
if let Some(custom_path) = custom_discovery_path {
|
||||||
|
discovery_paths.push(custom_path);
|
||||||
|
}
|
||||||
|
discovery_paths.extend([
|
||||||
|
"/.well-known/oauth-authorization-server",
|
||||||
|
"/.well-known/openid_configuration",
|
||||||
|
"/oauth/.well-known/oauth-authorization-server",
|
||||||
|
"/.well-known/oauth_authorization_server", // Some services use underscore
|
||||||
|
]);
|
||||||
|
|
||||||
|
let discovery_paths_for_error = discovery_paths.clone(); // Clone for error message
|
||||||
|
let mut last_error = None;
|
||||||
|
|
||||||
|
// Try each discovery path until one works
|
||||||
|
for path in discovery_paths {
|
||||||
|
match base_url.join(path) {
|
||||||
|
Ok(discovery_url) => {
|
||||||
|
tracing::debug!("Trying OAuth discovery at: {}", discovery_url);
|
||||||
|
|
||||||
|
match client.get(discovery_url.clone()).send().await {
|
||||||
|
Ok(resp) if resp.status().is_success() => {
|
||||||
|
match resp.json::<Value>().await {
|
||||||
|
Ok(oidc_config) => {
|
||||||
|
// Try to parse the OAuth configuration
|
||||||
|
match parse_oauth_config(oidc_config) {
|
||||||
|
Ok(endpoints) => {
|
||||||
|
tracing::info!(
|
||||||
|
"Successfully discovered OAuth endpoints at: {}",
|
||||||
|
discovery_url
|
||||||
|
);
|
||||||
|
return Ok(endpoints);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(
|
||||||
|
"Invalid OAuth config at {}: {}",
|
||||||
|
discovery_url,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
last_error = Some(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(
|
||||||
|
"Failed to parse JSON from {}: {}",
|
||||||
|
discovery_url,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
last_error = Some(e.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(resp) => {
|
||||||
|
tracing::debug!("HTTP {} from {}", resp.status(), discovery_url);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!("Request failed to {}: {}", discovery_url, e);
|
||||||
|
last_error = Some(e.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!("Invalid discovery URL {}{}: {}", host, path, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(last_error.unwrap_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"No OAuth discovery endpoint found at {}. Tried paths: {:?}",
|
||||||
|
host,
|
||||||
|
discovery_paths_for_error
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_oauth_config(oidc_config: Value) -> Result<OidcEndpoints> {
|
||||||
|
let authorization_endpoint = oidc_config
|
||||||
|
.get("authorization_endpoint")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OAuth configuration"))?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let token_endpoint = oidc_config
|
||||||
|
.get("token_endpoint")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OAuth configuration"))?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let registration_endpoint = oidc_config
|
||||||
|
.get("registration_endpoint")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
Ok(OidcEndpoints {
|
||||||
|
authorization_endpoint,
|
||||||
|
token_endpoint,
|
||||||
|
registration_endpoint,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform OAuth flow for a service
|
||||||
|
pub async fn authenticate_service(config: ServiceConfig, mcp_url: &str) -> Result<String> {
|
||||||
|
tracing::info!("Starting OAuth authentication for service...");
|
||||||
|
|
||||||
|
// Get the canonical resource URI for the MCP server
|
||||||
|
let resource_uri = config.get_canonical_resource_uri(mcp_url)?;
|
||||||
|
tracing::info!("Using resource URI: {}", resource_uri);
|
||||||
|
|
||||||
|
// Get OAuth endpoints using flexible discovery
|
||||||
|
let endpoints =
|
||||||
|
get_oauth_endpoints(&config.oauth_host, config.discovery_path.as_deref()).await?;
|
||||||
|
|
||||||
|
// Register dynamic client to get client_id
|
||||||
|
let client_id = OAuthFlow::register_client(&endpoints, &config).await?;
|
||||||
|
|
||||||
|
// Create and execute OAuth flow with the dynamic client_id
|
||||||
|
let flow = OAuthFlow::new(endpoints, client_id, config.redirect_uri);
|
||||||
|
|
||||||
|
let token_data = flow.execute(&resource_uri).await?;
|
||||||
|
|
||||||
|
tracing::info!("OAuth authentication successful!");
|
||||||
|
Ok(token_data.access_token)
|
||||||
|
}
|
||||||
81
crates/mcp-client/src/oauth_tests.rs
Normal file
81
crates/mcp-client/src/oauth_tests.rs
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::oauth::ServiceConfig;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_canonical_resource_uri_generation() {
|
||||||
|
let config = ServiceConfig {
|
||||||
|
oauth_host: "https://example.com".to_string(),
|
||||||
|
redirect_uri: "http://localhost:8020".to_string(),
|
||||||
|
client_name: "Test Client".to_string(),
|
||||||
|
client_uri: "https://test.com".to_string(),
|
||||||
|
discovery_path: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test basic URL
|
||||||
|
let result = config
|
||||||
|
.get_canonical_resource_uri("https://mcp.example.com/mcp")
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "https://mcp.example.com/mcp");
|
||||||
|
|
||||||
|
// Test URL with port
|
||||||
|
let result = config
|
||||||
|
.get_canonical_resource_uri("https://mcp.example.com:8443/mcp")
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "https://mcp.example.com:8443/mcp");
|
||||||
|
|
||||||
|
// Test URL without path
|
||||||
|
let result = config
|
||||||
|
.get_canonical_resource_uri("https://mcp.example.com")
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "https://mcp.example.com");
|
||||||
|
|
||||||
|
// Test URL with root path
|
||||||
|
let result = config
|
||||||
|
.get_canonical_resource_uri("https://mcp.example.com/")
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "https://mcp.example.com");
|
||||||
|
|
||||||
|
// Test case normalization
|
||||||
|
let result = config
|
||||||
|
.get_canonical_resource_uri("HTTPS://MCP.EXAMPLE.COM/mcp")
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "https://mcp.example.com/mcp");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_service_config_from_mcp_endpoint() {
|
||||||
|
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com/api/mcp").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.oauth_host, "https://mcp.example.com");
|
||||||
|
assert_eq!(config.redirect_uri, "http://localhost:8020");
|
||||||
|
assert_eq!(config.client_name, "Goose MCP Client");
|
||||||
|
assert_eq!(config.client_uri, "https://github.com/block/goose");
|
||||||
|
assert!(config.discovery_path.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_service_config_with_port() {
|
||||||
|
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com:8443/mcp").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.oauth_host, "https://mcp.example.com:8443");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_service_config_invalid_url() {
|
||||||
|
let result = ServiceConfig::from_mcp_endpoint("invalid-url");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_custom_discovery_path() {
|
||||||
|
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com/mcp")
|
||||||
|
.unwrap()
|
||||||
|
.with_custom_discovery("/custom/oauth/discovery".to_string());
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
config.discovery_path,
|
||||||
|
Some("/custom/oauth/discovery".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::oauth::{authenticate_service, ServiceConfig};
|
||||||
use crate::transport::Error;
|
use crate::transport::Error;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use eventsource_client::{Client, SSE};
|
use eventsource_client::{Client, SSE};
|
||||||
@@ -8,7 +9,7 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||||
use tokio::time::Duration;
|
use tokio::time::Duration;
|
||||||
use tracing::{debug, error, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use super::{serialize_and_send, Transport, TransportHandle};
|
use super::{serialize_and_send, Transport, TransportHandle};
|
||||||
@@ -91,13 +92,46 @@ impl StreamableHttpActor {
|
|||||||
JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. })
|
JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. })
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Try to send the request
|
||||||
|
match self.send_request(&message_str, expects_response).await {
|
||||||
|
Ok(()) => Ok(()),
|
||||||
|
Err(Error::HttpError { status, .. }) if status == 401 || status == 403 => {
|
||||||
|
// Authentication challenge - try to authenticate and retry
|
||||||
|
info!(
|
||||||
|
"Received authentication challenge ({}), attempting OAuth flow...",
|
||||||
|
status
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(token) = self.attempt_authentication().await? {
|
||||||
|
info!("Authentication successful, retrying request...");
|
||||||
|
self.headers
|
||||||
|
.insert("Authorization".to_string(), format!("Bearer {}", token));
|
||||||
|
self.send_request(&message_str, expects_response).await
|
||||||
|
} else {
|
||||||
|
Err(Error::StreamableHttpError(
|
||||||
|
"Authentication failed - service not supported or OAuth flow failed"
|
||||||
|
.to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send an HTTP request to the MCP endpoint
|
||||||
|
async fn send_request(
|
||||||
|
&mut self,
|
||||||
|
message_str: &str,
|
||||||
|
expects_response: bool,
|
||||||
|
) -> Result<(), Error> {
|
||||||
// Build the HTTP request
|
// Build the HTTP request
|
||||||
let mut request = self
|
let mut request = self
|
||||||
.http_client
|
.http_client
|
||||||
.post(&self.mcp_endpoint)
|
.post(&self.mcp_endpoint)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Accept", "application/json, text/event-stream")
|
.header("Accept", "application/json, text/event-stream")
|
||||||
.body(message_str);
|
.header("MCP-Protocol-Version", "2025-06-18") // Required protocol version header
|
||||||
|
.body(message_str.to_string());
|
||||||
|
|
||||||
// Add session ID header if we have one
|
// Add session ID header if we have one
|
||||||
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
||||||
@@ -173,6 +207,36 @@ impl StreamableHttpActor {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Attempt to authenticate with the service
|
||||||
|
async fn attempt_authentication(&self) -> Result<Option<String>, Error> {
|
||||||
|
info!("Attempting to authenticate with service...");
|
||||||
|
|
||||||
|
// Create a generic OAuth configuration from the MCP endpoint
|
||||||
|
match ServiceConfig::from_mcp_endpoint(&self.mcp_endpoint) {
|
||||||
|
Ok(config) => {
|
||||||
|
info!("Created OAuth config for endpoint: {}", self.mcp_endpoint);
|
||||||
|
|
||||||
|
match authenticate_service(config, &self.mcp_endpoint).await {
|
||||||
|
Ok(token) => {
|
||||||
|
info!("OAuth authentication successful!");
|
||||||
|
Ok(Some(token))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("OAuth authentication failed: {}", e);
|
||||||
|
Err(Error::StreamableHttpError(format!("OAuth failed: {}", e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"Could not create OAuth config from MCP endpoint {}: {}",
|
||||||
|
self.mcp_endpoint, e
|
||||||
|
);
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Handle streaming HTTP response that uses Server-Sent Events format
|
/// Handle streaming HTTP response that uses Server-Sent Events format
|
||||||
///
|
///
|
||||||
/// This is called when the server responds to an HTTP POST with `text/event-stream`
|
/// This is called when the server responds to an HTTP POST with `text/event-stream`
|
||||||
@@ -263,7 +327,8 @@ impl StreamableHttpTransportHandle {
|
|||||||
let mut request = self
|
let mut request = self
|
||||||
.http_client
|
.http_client
|
||||||
.delete(&self.mcp_endpoint)
|
.delete(&self.mcp_endpoint)
|
||||||
.header("Mcp-Session-Id", session_id);
|
.header("Mcp-Session-Id", session_id)
|
||||||
|
.header("MCP-Protocol-Version", "2025-06-18"); // Required protocol version header
|
||||||
|
|
||||||
// Add custom headers
|
// Add custom headers
|
||||||
for (key, value) in &self.headers {
|
for (key, value) in &self.headers {
|
||||||
@@ -290,7 +355,8 @@ impl StreamableHttpTransportHandle {
|
|||||||
let mut request = self
|
let mut request = self
|
||||||
.http_client
|
.http_client
|
||||||
.get(&self.mcp_endpoint)
|
.get(&self.mcp_endpoint)
|
||||||
.header("Accept", "text/event-stream");
|
.header("Accept", "text/event-stream")
|
||||||
|
.header("MCP-Protocol-Version", "2025-06-18"); // Required protocol version header
|
||||||
|
|
||||||
// Add session ID header if we have one
|
// Add session ID header if we have one
|
||||||
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
||||||
|
|||||||
Reference in New Issue
Block a user