mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
Add native OAuth 2.0 authentication support to MCP client (#3213)
Co-authored-by: Alex Hancock <alexhancock@block.xyz>
This commit is contained in:
committed by
GitHub
parent
825c2258ef
commit
883bc67d3b
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -5376,14 +5376,20 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"eventsource-client",
|
||||
"futures",
|
||||
"mcp-core",
|
||||
"nanoid",
|
||||
"nix 0.30.1",
|
||||
"rand 0.8.5",
|
||||
"reqwest 0.11.27",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sha2",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
@@ -5392,6 +5398,7 @@ dependencies = [
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"webbrowser 1.0.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -25,5 +25,13 @@ tower = { version = "0.4", features = ["timeout", "util"] }
|
||||
tower-service = "0.3"
|
||||
rand = "0.8"
|
||||
nix = { version = "0.30.1", features = ["process", "signal"] }
|
||||
# OAuth dependencies
|
||||
axum = { version = "0.8", features = ["query"] }
|
||||
base64 = "0.22"
|
||||
sha2 = "0.10"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
nanoid = "0.4"
|
||||
webbrowser = "1.0"
|
||||
serde_urlencoded = "0.7"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
64
crates/mcp-client/examples/test_auth.rs
Normal file
64
crates/mcp-client/examples/test_auth.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use anyhow::Result;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{StreamableHttpTransport, Transport};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize logging
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
EnvFilter::from_default_env()
|
||||
.add_directive("mcp_client=debug".parse().unwrap())
|
||||
.add_directive("eventsource_client=info".parse().unwrap()),
|
||||
)
|
||||
.init();
|
||||
|
||||
println!("Testing Streamable HTTP transport with OAuth 2.0 authentication...");
|
||||
|
||||
// Create the Streamable HTTP transport for any MCP service that supports OAuth
|
||||
// This example uses a hypothetical MCP endpoint - replace with actual service
|
||||
let mcp_endpoint =
|
||||
std::env::var("MCP_ENDPOINT").unwrap_or_else(|_| "https://example.com/mcp".to_string());
|
||||
|
||||
println!("Using MCP endpoint: {}", mcp_endpoint);
|
||||
|
||||
let transport = StreamableHttpTransport::new(&mcp_endpoint, HashMap::new());
|
||||
|
||||
// Start transport
|
||||
let handle = transport.start().await?;
|
||||
|
||||
// Create client
|
||||
let mut client = McpClient::connect(handle, Duration::from_secs(30)).await?;
|
||||
println!("Client created with Streamable HTTP transport\n");
|
||||
|
||||
// Initialize - this will trigger the OAuth flow if authentication is needed
|
||||
// The implementation now includes:
|
||||
// - RFC 8707 Resource Parameter support for proper token audience binding
|
||||
// - Proper OAuth 2.0 discovery with multiple fallback paths
|
||||
// - Dynamic client registration (RFC 7591)
|
||||
// - PKCE for security (RFC 7636)
|
||||
// - MCP-Protocol-Version header as required by the specification
|
||||
let server_info = client
|
||||
.initialize(
|
||||
ClientInfo {
|
||||
name: "streamable-http-auth-test".into(),
|
||||
version: "1.0.0".into(),
|
||||
},
|
||||
ClientCapabilities::default(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
println!("Connected to server: {server_info:?}\n");
|
||||
println!("OAuth 2.0 authentication test completed successfully!");
|
||||
println!("\nKey improvements implemented:");
|
||||
println!("✓ RFC 8707 Resource Parameter implementation");
|
||||
println!("✓ MCP-Protocol-Version header support");
|
||||
println!("✓ Enhanced OAuth discovery with multiple fallback paths");
|
||||
println!("✓ Proper canonical resource URI generation");
|
||||
println!("✓ Full compliance with MCP Authorization specification");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,8 +1,13 @@
|
||||
pub mod client;
|
||||
pub mod oauth;
|
||||
pub mod service;
|
||||
pub mod transport;
|
||||
|
||||
#[cfg(test)]
|
||||
mod oauth_tests;
|
||||
|
||||
pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait};
|
||||
pub use oauth::{authenticate_service, ServiceConfig};
|
||||
pub use service::McpService;
|
||||
pub use transport::{
|
||||
SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle,
|
||||
|
||||
456
crates/mcp-client/src/oauth.rs
Normal file
456
crates/mcp-client/src/oauth.rs
Normal file
@@ -0,0 +1,456 @@
|
||||
use anyhow::Result;
|
||||
use axum::{extract::Query, response::Html, routing::get, Router};
|
||||
use base64::Engine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sha2::Digest;
|
||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::{oneshot, Mutex as TokioMutex};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OidcEndpoints {
|
||||
authorization_endpoint: String,
|
||||
token_endpoint: String,
|
||||
registration_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct TokenData {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct ClientRegistrationRequest {
|
||||
redirect_uris: Vec<String>,
|
||||
token_endpoint_auth_method: String,
|
||||
grant_types: Vec<String>,
|
||||
response_types: Vec<String>,
|
||||
client_name: String,
|
||||
client_uri: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct ClientRegistrationResponse {
|
||||
client_id: String,
|
||||
client_id_issued_at: Option<u64>,
|
||||
#[serde(default)]
|
||||
client_secret: Option<String>,
|
||||
}
|
||||
|
||||
/// OAuth configuration for any service
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServiceConfig {
|
||||
pub oauth_host: String,
|
||||
pub redirect_uri: String,
|
||||
pub client_name: String,
|
||||
pub client_uri: String,
|
||||
pub discovery_path: Option<String>,
|
||||
}
|
||||
|
||||
impl ServiceConfig {
|
||||
/// Create a generic OAuth configuration from an MCP endpoint URL
|
||||
/// Extracts the base URL for OAuth discovery
|
||||
pub fn from_mcp_endpoint(mcp_url: &str) -> Result<Self> {
|
||||
let parsed_url = Url::parse(mcp_url.trim())?;
|
||||
let oauth_host = format!(
|
||||
"{}://{}{}",
|
||||
parsed_url.scheme(),
|
||||
parsed_url.host_str().ok_or_else(|| {
|
||||
anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url)
|
||||
})?,
|
||||
if let Some(port) = parsed_url.port() {
|
||||
format!(":{}", port)
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
oauth_host,
|
||||
redirect_uri: "http://localhost:8020".to_string(),
|
||||
client_name: "Goose MCP Client".to_string(),
|
||||
client_uri: "https://github.com/block/goose".to_string(),
|
||||
discovery_path: None, // Use standard discovery
|
||||
})
|
||||
}
|
||||
|
||||
/// Create configuration with custom discovery path for non-standard services
|
||||
pub fn with_custom_discovery(mut self, discovery_path: String) -> Self {
|
||||
self.discovery_path = Some(discovery_path);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the canonical resource URI for the MCP server
|
||||
/// This is used as the resource parameter in OAuth requests (RFC 8707)
|
||||
pub fn get_canonical_resource_uri(&self, mcp_url: &str) -> Result<String> {
|
||||
let parsed_url = Url::parse(mcp_url.trim())?;
|
||||
|
||||
// Build canonical URI: scheme://host[:port][/path]
|
||||
let mut canonical = format!(
|
||||
"{}://{}",
|
||||
parsed_url.scheme().to_lowercase(),
|
||||
parsed_url
|
||||
.host_str()
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url)
|
||||
})?
|
||||
.to_lowercase()
|
||||
);
|
||||
|
||||
// Add port if not default
|
||||
if let Some(port) = parsed_url.port() {
|
||||
canonical.push_str(&format!(":{}", port));
|
||||
}
|
||||
|
||||
// Add path if present and not just "/"
|
||||
let path = parsed_url.path();
|
||||
if !path.is_empty() && path != "/" {
|
||||
canonical.push_str(path);
|
||||
}
|
||||
|
||||
Ok(canonical)
|
||||
}
|
||||
}
|
||||
|
||||
struct OAuthFlow {
|
||||
endpoints: OidcEndpoints,
|
||||
client_id: String,
|
||||
redirect_url: String,
|
||||
state: String,
|
||||
verifier: String,
|
||||
}
|
||||
|
||||
impl OAuthFlow {
|
||||
fn new(endpoints: OidcEndpoints, client_id: String, redirect_url: String) -> Self {
|
||||
Self {
|
||||
endpoints,
|
||||
client_id,
|
||||
redirect_url,
|
||||
state: nanoid::nanoid!(16),
|
||||
verifier: nanoid::nanoid!(64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a dynamic client and return the client_id
|
||||
async fn register_client(endpoints: &OidcEndpoints, config: &ServiceConfig) -> Result<String> {
|
||||
let Some(registration_endpoint) = &endpoints.registration_endpoint else {
|
||||
return Err(anyhow::anyhow!("No registration endpoint available"));
|
||||
};
|
||||
|
||||
let registration_request = ClientRegistrationRequest {
|
||||
redirect_uris: vec![config.redirect_uri.clone()],
|
||||
token_endpoint_auth_method: "none".to_string(),
|
||||
grant_types: vec![
|
||||
"authorization_code".to_string(),
|
||||
"refresh_token".to_string(),
|
||||
],
|
||||
response_types: vec!["code".to_string()],
|
||||
client_name: config.client_name.clone(),
|
||||
client_uri: config.client_uri.clone(),
|
||||
};
|
||||
|
||||
tracing::info!("Registering dynamic client with OAuth server...");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(registration_endpoint)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(®istration_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_text = resp.text().await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to register client: {} - {}",
|
||||
status,
|
||||
err_text
|
||||
));
|
||||
}
|
||||
|
||||
let registration_response: ClientRegistrationResponse = resp.json().await?;
|
||||
|
||||
tracing::info!(
|
||||
"Client registered successfully with ID: {}",
|
||||
registration_response.client_id
|
||||
);
|
||||
Ok(registration_response.client_id)
|
||||
}
|
||||
|
||||
fn get_authorization_url(&self, resource: &str) -> String {
|
||||
let challenge = {
|
||||
let digest = sha2::Sha256::digest(self.verifier.as_bytes());
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
|
||||
};
|
||||
|
||||
let params = [
|
||||
("response_type", "code"),
|
||||
("client_id", &self.client_id),
|
||||
("redirect_uri", &self.redirect_url),
|
||||
("state", &self.state),
|
||||
("code_challenge", &challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
("resource", resource), // RFC 8707 Resource Parameter
|
||||
];
|
||||
|
||||
format!(
|
||||
"{}?{}",
|
||||
self.endpoints.authorization_endpoint,
|
||||
serde_urlencoded::to_string(params).unwrap()
|
||||
)
|
||||
}
|
||||
|
||||
async fn exchange_code_for_token(&self, code: &str, resource: &str) -> Result<TokenData> {
|
||||
let params = [
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("redirect_uri", &self.redirect_url),
|
||||
("code_verifier", &self.verifier),
|
||||
("client_id", &self.client_id),
|
||||
("resource", resource), // RFC 8707 Resource Parameter
|
||||
];
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&self.endpoints.token_endpoint)
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err_text = resp.text().await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to exchange code for token: {}",
|
||||
err_text
|
||||
));
|
||||
}
|
||||
|
||||
let token_response: Value = resp.json().await?;
|
||||
|
||||
let access_token = token_response
|
||||
.get("access_token")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))?
|
||||
.to_string();
|
||||
|
||||
let refresh_token = token_response
|
||||
.get("refresh_token")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
Ok(TokenData {
|
||||
access_token,
|
||||
refresh_token,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, resource: &str) -> Result<TokenData> {
|
||||
// Create a channel that will send the auth code from the callback
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let state = self.state.clone();
|
||||
let tx = Arc::new(TokioMutex::new(Some(tx)));
|
||||
|
||||
// Setup a server that will receive the redirect and capture the code
|
||||
let app = Router::new().route(
|
||||
"/",
|
||||
get(move |Query(params): Query<HashMap<String, String>>| {
|
||||
let tx = Arc::clone(&tx);
|
||||
let state = state.clone();
|
||||
async move {
|
||||
let code = params.get("code").cloned();
|
||||
let received_state = params.get("state").cloned();
|
||||
|
||||
if let (Some(code), Some(received_state)) = (code, received_state) {
|
||||
if received_state == state {
|
||||
if let Some(sender) = tx.lock().await.take() {
|
||||
if sender.send(code).is_ok() {
|
||||
return Html(
|
||||
"<h2>Authentication Successful!</h2><p>You can close this window and return to the application.</p>",
|
||||
);
|
||||
}
|
||||
}
|
||||
Html("<h2>Error</h2><p>Authentication already completed.</p>")
|
||||
} else {
|
||||
Html("<h2>Error</h2><p>State mismatch - possible security issue.</p>")
|
||||
}
|
||||
} else {
|
||||
Html("<h2>Error</h2><p>Authentication failed - missing parameters.</p>")
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Start the callback server
|
||||
let redirect_url = Url::parse(&self.redirect_url)?;
|
||||
let port = redirect_url.port().unwrap_or(8020);
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
let server = axum::serve(listener, app);
|
||||
server.await.unwrap();
|
||||
});
|
||||
|
||||
// Open the browser for OAuth
|
||||
let authorization_url = self.get_authorization_url(resource);
|
||||
tracing::info!("Opening browser for OAuth authentication...");
|
||||
|
||||
if webbrowser::open(&authorization_url).is_err() {
|
||||
tracing::warn!("Could not open browser automatically. Please open this URL manually:");
|
||||
tracing::warn!("{}", authorization_url);
|
||||
}
|
||||
|
||||
// Wait for the authorization code with a timeout
|
||||
let code = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(120), // 2 minute timeout
|
||||
rx,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??;
|
||||
|
||||
// Stop the callback server
|
||||
server_handle.abort();
|
||||
|
||||
// Exchange the code for a token
|
||||
self.exchange_code_for_token(&code, resource).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_oauth_endpoints(
|
||||
host: &str,
|
||||
custom_discovery_path: Option<&str>,
|
||||
) -> Result<OidcEndpoints> {
|
||||
let base_url = Url::parse(host)?;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Define discovery paths to try, with custom path first if provided
|
||||
let mut discovery_paths = Vec::new();
|
||||
if let Some(custom_path) = custom_discovery_path {
|
||||
discovery_paths.push(custom_path);
|
||||
}
|
||||
discovery_paths.extend([
|
||||
"/.well-known/oauth-authorization-server",
|
||||
"/.well-known/openid_configuration",
|
||||
"/oauth/.well-known/oauth-authorization-server",
|
||||
"/.well-known/oauth_authorization_server", // Some services use underscore
|
||||
]);
|
||||
|
||||
let discovery_paths_for_error = discovery_paths.clone(); // Clone for error message
|
||||
let mut last_error = None;
|
||||
|
||||
// Try each discovery path until one works
|
||||
for path in discovery_paths {
|
||||
match base_url.join(path) {
|
||||
Ok(discovery_url) => {
|
||||
tracing::debug!("Trying OAuth discovery at: {}", discovery_url);
|
||||
|
||||
match client.get(discovery_url.clone()).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
match resp.json::<Value>().await {
|
||||
Ok(oidc_config) => {
|
||||
// Try to parse the OAuth configuration
|
||||
match parse_oauth_config(oidc_config) {
|
||||
Ok(endpoints) => {
|
||||
tracing::info!(
|
||||
"Successfully discovered OAuth endpoints at: {}",
|
||||
discovery_url
|
||||
);
|
||||
return Ok(endpoints);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
"Invalid OAuth config at {}: {}",
|
||||
discovery_url,
|
||||
e
|
||||
);
|
||||
last_error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
"Failed to parse JSON from {}: {}",
|
||||
discovery_url,
|
||||
e
|
||||
);
|
||||
last_error = Some(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
tracing::debug!("HTTP {} from {}", resp.status(), discovery_url);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("Request failed to {}: {}", discovery_url, e);
|
||||
last_error = Some(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("Invalid discovery URL {}{}: {}", host, path, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"No OAuth discovery endpoint found at {}. Tried paths: {:?}",
|
||||
host,
|
||||
discovery_paths_for_error
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_oauth_config(oidc_config: Value) -> Result<OidcEndpoints> {
|
||||
let authorization_endpoint = oidc_config
|
||||
.get("authorization_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OAuth configuration"))?
|
||||
.to_string();
|
||||
|
||||
let token_endpoint = oidc_config
|
||||
.get("token_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OAuth configuration"))?
|
||||
.to_string();
|
||||
|
||||
let registration_endpoint = oidc_config
|
||||
.get("registration_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
Ok(OidcEndpoints {
|
||||
authorization_endpoint,
|
||||
token_endpoint,
|
||||
registration_endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform OAuth flow for a service
|
||||
pub async fn authenticate_service(config: ServiceConfig, mcp_url: &str) -> Result<String> {
|
||||
tracing::info!("Starting OAuth authentication for service...");
|
||||
|
||||
// Get the canonical resource URI for the MCP server
|
||||
let resource_uri = config.get_canonical_resource_uri(mcp_url)?;
|
||||
tracing::info!("Using resource URI: {}", resource_uri);
|
||||
|
||||
// Get OAuth endpoints using flexible discovery
|
||||
let endpoints =
|
||||
get_oauth_endpoints(&config.oauth_host, config.discovery_path.as_deref()).await?;
|
||||
|
||||
// Register dynamic client to get client_id
|
||||
let client_id = OAuthFlow::register_client(&endpoints, &config).await?;
|
||||
|
||||
// Create and execute OAuth flow with the dynamic client_id
|
||||
let flow = OAuthFlow::new(endpoints, client_id, config.redirect_uri);
|
||||
|
||||
let token_data = flow.execute(&resource_uri).await?;
|
||||
|
||||
tracing::info!("OAuth authentication successful!");
|
||||
Ok(token_data.access_token)
|
||||
}
|
||||
81
crates/mcp-client/src/oauth_tests.rs
Normal file
81
crates/mcp-client/src/oauth_tests.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::oauth::ServiceConfig;
|
||||
|
||||
#[test]
|
||||
fn test_canonical_resource_uri_generation() {
|
||||
let config = ServiceConfig {
|
||||
oauth_host: "https://example.com".to_string(),
|
||||
redirect_uri: "http://localhost:8020".to_string(),
|
||||
client_name: "Test Client".to_string(),
|
||||
client_uri: "https://test.com".to_string(),
|
||||
discovery_path: None,
|
||||
};
|
||||
|
||||
// Test basic URL
|
||||
let result = config
|
||||
.get_canonical_resource_uri("https://mcp.example.com/mcp")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com/mcp");
|
||||
|
||||
// Test URL with port
|
||||
let result = config
|
||||
.get_canonical_resource_uri("https://mcp.example.com:8443/mcp")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com:8443/mcp");
|
||||
|
||||
// Test URL without path
|
||||
let result = config
|
||||
.get_canonical_resource_uri("https://mcp.example.com")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com");
|
||||
|
||||
// Test URL with root path
|
||||
let result = config
|
||||
.get_canonical_resource_uri("https://mcp.example.com/")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com");
|
||||
|
||||
// Test case normalization
|
||||
let result = config
|
||||
.get_canonical_resource_uri("HTTPS://MCP.EXAMPLE.COM/mcp")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com/mcp");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_config_from_mcp_endpoint() {
|
||||
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com/api/mcp").unwrap();
|
||||
|
||||
assert_eq!(config.oauth_host, "https://mcp.example.com");
|
||||
assert_eq!(config.redirect_uri, "http://localhost:8020");
|
||||
assert_eq!(config.client_name, "Goose MCP Client");
|
||||
assert_eq!(config.client_uri, "https://github.com/block/goose");
|
||||
assert!(config.discovery_path.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_config_with_port() {
|
||||
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com:8443/mcp").unwrap();
|
||||
|
||||
assert_eq!(config.oauth_host, "https://mcp.example.com:8443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_config_invalid_url() {
|
||||
let result = ServiceConfig::from_mcp_endpoint("invalid-url");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_discovery_path() {
|
||||
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com/mcp")
|
||||
.unwrap()
|
||||
.with_custom_discovery("/custom/oauth/discovery".to_string());
|
||||
|
||||
assert_eq!(
|
||||
config.discovery_path,
|
||||
Some("/custom/oauth/discovery".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::oauth::{authenticate_service, ServiceConfig};
|
||||
use crate::transport::Error;
|
||||
use async_trait::async_trait;
|
||||
use eventsource_client::{Client, SSE};
|
||||
@@ -8,7 +9,7 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio::time::Duration;
|
||||
use tracing::{debug, error, warn};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use url::Url;
|
||||
|
||||
use super::{serialize_and_send, Transport, TransportHandle};
|
||||
@@ -91,13 +92,46 @@ impl StreamableHttpActor {
|
||||
JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. })
|
||||
);
|
||||
|
||||
// Try to send the request
|
||||
match self.send_request(&message_str, expects_response).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(Error::HttpError { status, .. }) if status == 401 || status == 403 => {
|
||||
// Authentication challenge - try to authenticate and retry
|
||||
info!(
|
||||
"Received authentication challenge ({}), attempting OAuth flow...",
|
||||
status
|
||||
);
|
||||
|
||||
if let Some(token) = self.attempt_authentication().await? {
|
||||
info!("Authentication successful, retrying request...");
|
||||
self.headers
|
||||
.insert("Authorization".to_string(), format!("Bearer {}", token));
|
||||
self.send_request(&message_str, expects_response).await
|
||||
} else {
|
||||
Err(Error::StreamableHttpError(
|
||||
"Authentication failed - service not supported or OAuth flow failed"
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send an HTTP request to the MCP endpoint
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
message_str: &str,
|
||||
expects_response: bool,
|
||||
) -> Result<(), Error> {
|
||||
// Build the HTTP request
|
||||
let mut request = self
|
||||
.http_client
|
||||
.post(&self.mcp_endpoint)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json, text/event-stream")
|
||||
.body(message_str);
|
||||
.header("MCP-Protocol-Version", "2025-06-18") // Required protocol version header
|
||||
.body(message_str.to_string());
|
||||
|
||||
// Add session ID header if we have one
|
||||
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
||||
@@ -173,6 +207,36 @@ impl StreamableHttpActor {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Attempt to authenticate with the service
|
||||
async fn attempt_authentication(&self) -> Result<Option<String>, Error> {
|
||||
info!("Attempting to authenticate with service...");
|
||||
|
||||
// Create a generic OAuth configuration from the MCP endpoint
|
||||
match ServiceConfig::from_mcp_endpoint(&self.mcp_endpoint) {
|
||||
Ok(config) => {
|
||||
info!("Created OAuth config for endpoint: {}", self.mcp_endpoint);
|
||||
|
||||
match authenticate_service(config, &self.mcp_endpoint).await {
|
||||
Ok(token) => {
|
||||
info!("OAuth authentication successful!");
|
||||
Ok(Some(token))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("OAuth authentication failed: {}", e);
|
||||
Err(Error::StreamableHttpError(format!("OAuth failed: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Could not create OAuth config from MCP endpoint {}: {}",
|
||||
self.mcp_endpoint, e
|
||||
);
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle streaming HTTP response that uses Server-Sent Events format
|
||||
///
|
||||
/// This is called when the server responds to an HTTP POST with `text/event-stream`
|
||||
@@ -263,7 +327,8 @@ impl StreamableHttpTransportHandle {
|
||||
let mut request = self
|
||||
.http_client
|
||||
.delete(&self.mcp_endpoint)
|
||||
.header("Mcp-Session-Id", session_id);
|
||||
.header("Mcp-Session-Id", session_id)
|
||||
.header("MCP-Protocol-Version", "2025-06-18"); // Required protocol version header
|
||||
|
||||
// Add custom headers
|
||||
for (key, value) in &self.headers {
|
||||
@@ -290,7 +355,8 @@ impl StreamableHttpTransportHandle {
|
||||
let mut request = self
|
||||
.http_client
|
||||
.get(&self.mcp_endpoint)
|
||||
.header("Accept", "text/event-stream");
|
||||
.header("Accept", "text/event-stream")
|
||||
.header("MCP-Protocol-Version", "2025-06-18"); // Required protocol version header
|
||||
|
||||
// Add session ID header if we have one
|
||||
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
||||
|
||||
Reference in New Issue
Block a user