mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 06:34:26 +01:00
feat: Adding streamable-http transport support for backend, desktop and cli (#2942)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -5321,6 +5321,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower 0.4.13",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
|
||||
@@ -56,6 +56,24 @@ enum ExtensionConfigRequest {
|
||||
display_name: Option<String>,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
/// Streamable HTTP extension using MCP Streamable HTTP specification.
|
||||
#[serde(rename = "streamable_http")]
|
||||
StreamableHttp {
|
||||
/// The name to identify this extension
|
||||
name: String,
|
||||
/// The URI endpoint for the streamable HTTP extension.
|
||||
uri: String,
|
||||
#[serde(default)]
|
||||
/// Map of environment variable key to values.
|
||||
envs: Envs,
|
||||
/// List of environment variable keys. The server will fetch their values from the keyring.
|
||||
#[serde(default)]
|
||||
env_keys: Vec<String>,
|
||||
/// Custom headers to include in requests.
|
||||
#[serde(default)]
|
||||
headers: std::collections::HashMap<String, String>,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
/// Frontend extension that provides tools to be executed by the frontend.
|
||||
#[serde(rename = "frontend")]
|
||||
Frontend {
|
||||
@@ -176,6 +194,23 @@ async fn add_extension(
|
||||
timeout,
|
||||
bundled: None,
|
||||
},
|
||||
ExtensionConfigRequest::StreamableHttp {
|
||||
name,
|
||||
uri,
|
||||
envs,
|
||||
env_keys,
|
||||
headers,
|
||||
timeout,
|
||||
} => ExtensionConfig::StreamableHttp {
|
||||
name,
|
||||
uri,
|
||||
envs,
|
||||
env_keys,
|
||||
headers,
|
||||
description: None,
|
||||
timeout,
|
||||
bundled: None,
|
||||
},
|
||||
ExtensionConfigRequest::Stdio {
|
||||
name,
|
||||
cmd,
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::config::permission::PermissionLevel;
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ExtensionError {
|
||||
#[error("Failed to start the MCP server from configuration `{0}` `{1}`")]
|
||||
Initialization(ExtensionConfig, ClientError),
|
||||
Initialization(Box<ExtensionConfig>, ClientError),
|
||||
#[error("Failed a client call to an MCP server: {0}")]
|
||||
Client(#[from] ClientError),
|
||||
#[error("User Message exceeded context-limit. History could not be truncated to accommodate.")]
|
||||
@@ -54,7 +54,7 @@ impl Envs {
|
||||
"LD_AUDIT", // Loads a monitoring library that can intercept execution
|
||||
"LD_DEBUG", // Enables verbose linker logging (information disclosure risk)
|
||||
"LD_BIND_NOW", // Forces immediate symbol resolution, affecting ASLR
|
||||
"LD_ASSUME_KERNEL", // Tricks linker into thinking it’s running on an older kernel
|
||||
"LD_ASSUME_KERNEL", // Tricks linker into thinking it's running on an older kernel
|
||||
// 🍎 macOS dynamic linker variables
|
||||
"DYLD_LIBRARY_PATH", // Same as LD_LIBRARY_PATH but for macOS
|
||||
"DYLD_INSERT_LIBRARIES", // macOS equivalent of LD_PRELOAD
|
||||
@@ -168,6 +168,26 @@ pub enum ExtensionConfig {
|
||||
#[serde(default)]
|
||||
bundled: Option<bool>,
|
||||
},
|
||||
/// Streamable HTTP client with a URI endpoint using MCP Streamable HTTP specification
|
||||
#[serde(rename = "streamable_http")]
|
||||
StreamableHttp {
|
||||
/// The name used to identify this extension
|
||||
name: String,
|
||||
uri: String,
|
||||
#[serde(default)]
|
||||
envs: Envs,
|
||||
#[serde(default)]
|
||||
env_keys: Vec<String>,
|
||||
#[serde(default)]
|
||||
headers: HashMap<String, String>,
|
||||
description: Option<String>,
|
||||
// NOTE: set timeout to be optional for compatibility.
|
||||
// However, new configurations should include this field.
|
||||
timeout: Option<u64>,
|
||||
/// Whether this extension is bundled with Goose
|
||||
#[serde(default)]
|
||||
bundled: Option<bool>,
|
||||
},
|
||||
/// Frontend-provided tools that will be called through the frontend
|
||||
#[serde(rename = "frontend")]
|
||||
Frontend {
|
||||
@@ -207,6 +227,24 @@ impl ExtensionConfig {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn streamable_http<S: Into<String>, T: Into<u64>>(
|
||||
name: S,
|
||||
uri: S,
|
||||
description: S,
|
||||
timeout: T,
|
||||
) -> Self {
|
||||
Self::StreamableHttp {
|
||||
name: name.into(),
|
||||
uri: uri.into(),
|
||||
envs: Envs::default(),
|
||||
env_keys: Vec::new(),
|
||||
headers: HashMap::new(),
|
||||
description: Some(description.into()),
|
||||
timeout: Some(timeout.into()),
|
||||
bundled: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stdio<S: Into<String>, T: Into<u64>>(
|
||||
name: S,
|
||||
cmd: S,
|
||||
@@ -263,6 +301,7 @@ impl ExtensionConfig {
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
Self::Sse { name, .. } => name,
|
||||
Self::StreamableHttp { name, .. } => name,
|
||||
Self::Stdio { name, .. } => name,
|
||||
Self::Builtin { name, .. } => name,
|
||||
Self::Frontend { name, .. } => name,
|
||||
@@ -275,6 +314,9 @@ impl std::fmt::Display for ExtensionConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ExtensionConfig::Sse { name, uri, .. } => write!(f, "SSE({}: {})", name, uri),
|
||||
ExtensionConfig::StreamableHttp { name, uri, .. } => {
|
||||
write!(f, "StreamableHttp({}: {})", name, uri)
|
||||
}
|
||||
ExtensionConfig::Stdio {
|
||||
name, cmd, args, ..
|
||||
} => {
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::agents::extension::Envs;
|
||||
use crate::config::{Config, ExtensionConfigManager};
|
||||
use crate::prompt_template;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
|
||||
use mcp_client::transport::{SseTransport, StdioTransport, StreamableHttpTransport, Transport};
|
||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError};
|
||||
use serde_json::Value;
|
||||
|
||||
@@ -195,6 +195,28 @@ impl ExtensionManager {
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
ExtensionConfig::StreamableHttp {
|
||||
uri,
|
||||
envs,
|
||||
env_keys,
|
||||
headers,
|
||||
timeout,
|
||||
..
|
||||
} => {
|
||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||
let transport =
|
||||
StreamableHttpTransport::with_headers(uri, all_envs, headers.clone());
|
||||
let handle = transport.start().await?;
|
||||
Box::new(
|
||||
McpClient::connect(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
ExtensionConfig::Stdio {
|
||||
cmd,
|
||||
args,
|
||||
@@ -256,7 +278,7 @@ impl ExtensionManager {
|
||||
let init_result = client
|
||||
.initialize(info, capabilities)
|
||||
.await
|
||||
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
|
||||
.map_err(|e| ExtensionError::Initialization(Box::new(config.clone()), e))?;
|
||||
|
||||
if let Some(instructions) = init_result.instructions {
|
||||
self.instructions
|
||||
@@ -752,10 +774,13 @@ impl ExtensionManager {
|
||||
ExtensionConfig::Sse {
|
||||
description, name, ..
|
||||
}
|
||||
| ExtensionConfig::StreamableHttp {
|
||||
description, name, ..
|
||||
}
|
||||
| ExtensionConfig::Stdio {
|
||||
description, name, ..
|
||||
} => {
|
||||
// For SSE/Stdio, use description if available
|
||||
// For SSE/StreamableHttp/Stdio, use description if available
|
||||
description
|
||||
.as_ref()
|
||||
.map(|s| s.to_string())
|
||||
|
||||
@@ -9,6 +9,7 @@ workspace = true
|
||||
[dependencies]
|
||||
mcp-core = { path = "../mcp-core" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-util = { version = "0.7", features = ["io"] }
|
||||
reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "rustls-tls-native-roots"] }
|
||||
eventsource-client = "0.12.0"
|
||||
futures = "0.3"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use futures::lock::Mutex;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{SseTransport, Transport};
|
||||
use mcp_client::transport::{SseTransport, StreamableHttpTransport, Transport};
|
||||
use mcp_client::StdioTransport;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
@@ -20,6 +20,7 @@ async fn main() -> Result<()> {
|
||||
.init();
|
||||
|
||||
test_transport(sse_transport().await?).await?;
|
||||
test_transport(streamable_http_transport().await?).await?;
|
||||
test_transport(stdio_transport().await?).await?;
|
||||
|
||||
// Test broken transport
|
||||
@@ -52,6 +53,22 @@ async fn sse_transport() -> Result<SseTransport> {
|
||||
))
|
||||
}
|
||||
|
||||
async fn streamable_http_transport() -> Result<StreamableHttpTransport> {
|
||||
let port = "60054";
|
||||
|
||||
tokio::process::Command::new("npx")
|
||||
.env("PORT", port)
|
||||
.arg("@modelcontextprotocol/server-everything")
|
||||
.arg("streamable-http")
|
||||
.spawn()?;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
|
||||
Ok(StreamableHttpTransport::new(
|
||||
format!("http://localhost:{}/mcp", port),
|
||||
HashMap::new(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn stdio_transport() -> Result<StdioTransport> {
|
||||
Ok(StdioTransport::new(
|
||||
"npx",
|
||||
|
||||
93
crates/mcp-client/examples/streamable_http.rs
Normal file
93
crates/mcp-client/examples/streamable_http.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use anyhow::Result;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{StreamableHttpTransport, Transport};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize logging
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
EnvFilter::from_default_env()
|
||||
.add_directive("mcp_client=debug".parse().unwrap())
|
||||
.add_directive("eventsource_client=info".parse().unwrap()),
|
||||
)
|
||||
.init();
|
||||
|
||||
// Create example headers
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-Custom-Header".to_string(), "example-value".to_string());
|
||||
headers.insert(
|
||||
"User-Agent".to_string(),
|
||||
"MCP-StreamableHttp-Client/1.0".to_string(),
|
||||
);
|
||||
|
||||
// Create the Streamable HTTP transport with headers
|
||||
let transport =
|
||||
StreamableHttpTransport::with_headers("http://localhost:8000/mcp", HashMap::new(), headers);
|
||||
|
||||
// Start transport
|
||||
let handle = transport.start().await?;
|
||||
|
||||
// Create client
|
||||
let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?;
|
||||
println!("Client created with Streamable HTTP transport\n");
|
||||
|
||||
// Initialize
|
||||
let server_info = client
|
||||
.initialize(
|
||||
ClientInfo {
|
||||
name: "streamable-http-client".into(),
|
||||
version: "1.0.0".into(),
|
||||
},
|
||||
ClientCapabilities::default(),
|
||||
)
|
||||
.await?;
|
||||
println!("Connected to server: {server_info:?}\n");
|
||||
|
||||
// Give the server a moment to fully initialize
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// List tools
|
||||
let tools = client.list_tools(None).await?;
|
||||
println!("Available tools: {tools:?}\n");
|
||||
|
||||
// Call tool if available
|
||||
if !tools.tools.is_empty() {
|
||||
let tool_result = client
|
||||
.call_tool(
|
||||
&tools.tools[0].name,
|
||||
serde_json::json!({ "message": "Hello from Streamable HTTP transport!" }),
|
||||
)
|
||||
.await?;
|
||||
println!("Tool result: {tool_result:?}\n");
|
||||
}
|
||||
|
||||
// List resources
|
||||
let resources = client.list_resources(None).await?;
|
||||
println!("Resources: {resources:?}\n");
|
||||
|
||||
// Read resource if available
|
||||
if !resources.resources.is_empty() {
|
||||
let resource = client.read_resource(&resources.resources[0].uri).await?;
|
||||
println!("Resource content: {resource:?}\n");
|
||||
}
|
||||
|
||||
// List prompts
|
||||
let prompts = client.list_prompts(None).await?;
|
||||
println!("Available prompts: {prompts:?}\n");
|
||||
|
||||
// Get prompt if available
|
||||
if !prompts.prompts.is_empty() {
|
||||
let prompt_result = client
|
||||
.get_prompt(&prompts.prompts[0].name, serde_json::json!({}))
|
||||
.await?;
|
||||
println!("Prompt result: {prompt_result:?}\n");
|
||||
}
|
||||
|
||||
println!("Streamable HTTP transport example completed successfully!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -4,4 +4,6 @@ pub mod transport;
|
||||
|
||||
pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait};
|
||||
pub use service::McpService;
|
||||
pub use transport::{SseTransport, StdioTransport, Transport, TransportHandle};
|
||||
pub use transport::{
|
||||
SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle,
|
||||
};
|
||||
|
||||
@@ -30,6 +30,12 @@ pub enum Error {
|
||||
|
||||
#[error("HTTP error: {status} - {message}")]
|
||||
HttpError { status: u16, message: String },
|
||||
|
||||
#[error("Streamable HTTP error: {0}")]
|
||||
StreamableHttpError(String),
|
||||
|
||||
#[error("Session error: {0}")]
|
||||
SessionError(String),
|
||||
}
|
||||
|
||||
/// A message that can be sent through the transport
|
||||
@@ -78,3 +84,6 @@ pub use stdio::StdioTransport;
|
||||
|
||||
pub mod sse;
|
||||
pub use sse::SseTransport;
|
||||
|
||||
pub mod streamable_http;
|
||||
pub use streamable_http::StreamableHttpTransport;
|
||||
|
||||
447
crates/mcp-client/src/transport/streamable_http.rs
Normal file
447
crates/mcp-client/src/transport/streamable_http.rs
Normal file
@@ -0,0 +1,447 @@
|
||||
use crate::transport::Error;
|
||||
use async_trait::async_trait;
|
||||
use eventsource_client::{Client, SSE};
|
||||
use futures::TryStreamExt;
|
||||
use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest};
|
||||
use reqwest::Client as HttpClient;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio::time::Duration;
|
||||
use tracing::{debug, error, warn};
|
||||
use url::Url;
|
||||
|
||||
use super::{serialize_and_send, Transport, TransportHandle};
|
||||
|
||||
// Default timeout for HTTP requests
|
||||
const HTTP_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// The Streamable HTTP transport actor that handles:
|
||||
/// - HTTP POST requests to send messages to the server
|
||||
/// - Optional streaming responses for receiving multiple responses and server-initiated messages
|
||||
/// - Session management with session IDs
|
||||
pub struct StreamableHttpActor {
|
||||
/// Receives messages (requests/notifications) from the handle
|
||||
receiver: mpsc::Receiver<String>,
|
||||
/// Sends messages (responses) back to the handle
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
/// MCP endpoint URL
|
||||
mcp_endpoint: String,
|
||||
/// HTTP client for sending requests
|
||||
http_client: HttpClient,
|
||||
/// Optional session ID for stateful connections
|
||||
session_id: Arc<RwLock<Option<String>>>,
|
||||
/// Environment variables to set
|
||||
env: HashMap<String, String>,
|
||||
/// Custom headers to include in requests
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl StreamableHttpActor {
|
||||
pub fn new(
|
||||
receiver: mpsc::Receiver<String>,
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
mcp_endpoint: String,
|
||||
session_id: Arc<RwLock<Option<String>>>,
|
||||
env: HashMap<String, String>,
|
||||
headers: HashMap<String, String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
receiver,
|
||||
sender,
|
||||
mcp_endpoint,
|
||||
http_client: HttpClient::builder()
|
||||
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
|
||||
.build()
|
||||
.unwrap(),
|
||||
session_id,
|
||||
env,
|
||||
headers,
|
||||
}
|
||||
}
|
||||
|
||||
/// Main entry point for the actor
|
||||
pub async fn run(mut self) {
|
||||
// Set environment variables
|
||||
for (key, value) in &self.env {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
|
||||
// Handle outgoing messages
|
||||
while let Some(message_str) = self.receiver.recv().await {
|
||||
if let Err(e) = self.handle_outgoing_message(message_str).await {
|
||||
error!("Error handling outgoing message: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
debug!("StreamableHttpActor shut down");
|
||||
}
|
||||
|
||||
/// Handle an outgoing message by sending it via HTTP POST
|
||||
async fn handle_outgoing_message(&mut self, message_str: String) -> Result<(), Error> {
|
||||
debug!("Sending message to MCP endpoint: {}", message_str);
|
||||
|
||||
// Parse the message to determine if it's a request that expects a response
|
||||
let parsed_message: JsonRpcMessage =
|
||||
serde_json::from_str(&message_str).map_err(Error::Serialization)?;
|
||||
|
||||
let expects_response = matches!(
|
||||
parsed_message,
|
||||
JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. })
|
||||
);
|
||||
|
||||
// Build the HTTP request
|
||||
let mut request = self
|
||||
.http_client
|
||||
.post(&self.mcp_endpoint)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json, text/event-stream")
|
||||
.body(message_str);
|
||||
|
||||
// Add session ID header if we have one
|
||||
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
||||
request = request.header("Mcp-Session-Id", session_id);
|
||||
}
|
||||
|
||||
// Add custom headers
|
||||
for (key, value) in &self.headers {
|
||||
request = request.header(key, value);
|
||||
}
|
||||
|
||||
// Send the request
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::StreamableHttpError(format!("HTTP request failed: {}", e)))?;
|
||||
|
||||
// Handle HTTP error status codes
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
if status.as_u16() == 404 {
|
||||
// Session not found - clear our session ID
|
||||
*self.session_id.write().await = None;
|
||||
return Err(Error::SessionError(
|
||||
"Session expired or not found".to_string(),
|
||||
));
|
||||
}
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(Error::HttpError {
|
||||
status: status.as_u16(),
|
||||
message: error_text,
|
||||
});
|
||||
}
|
||||
|
||||
// Check for session ID in response headers
|
||||
if let Some(session_id_header) = response.headers().get("Mcp-Session-Id") {
|
||||
if let Ok(session_id) = session_id_header.to_str() {
|
||||
debug!("Received session ID: {}", session_id);
|
||||
*self.session_id.write().await = Some(session_id.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the response based on content type
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if content_type.starts_with("text/event-stream") {
|
||||
// Handle streaming HTTP response (server chose to stream multiple messages back)
|
||||
if expects_response {
|
||||
self.handle_streaming_response(response).await?;
|
||||
}
|
||||
} else if content_type.starts_with("application/json") || expects_response {
|
||||
// Handle single JSON response
|
||||
let response_text = response.text().await.map_err(|e| {
|
||||
Error::StreamableHttpError(format!("Failed to read response: {}", e))
|
||||
})?;
|
||||
|
||||
if !response_text.is_empty() {
|
||||
let json_message: JsonRpcMessage =
|
||||
serde_json::from_str(&response_text).map_err(Error::Serialization)?;
|
||||
|
||||
let _ = self.sender.send(json_message).await;
|
||||
}
|
||||
}
|
||||
// For notifications and responses, we get 202 Accepted with no body
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle streaming HTTP response that uses Server-Sent Events format
|
||||
///
|
||||
/// This is called when the server responds to an HTTP POST with `text/event-stream`
|
||||
/// content-type, indicating it wants to stream multiple JSON-RPC messages back
|
||||
/// rather than sending a single response. This is part of the Streamable HTTP
|
||||
/// specification, not a separate SSE transport.
|
||||
async fn handle_streaming_response(
|
||||
&mut self,
|
||||
response: reqwest::Response,
|
||||
) -> Result<(), Error> {
|
||||
use futures::StreamExt;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio_util::io::StreamReader;
|
||||
|
||||
// Convert the response body to a stream reader
|
||||
let stream = response
|
||||
.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
let reader = StreamReader::new(stream);
|
||||
let mut lines = tokio::io::BufReader::new(reader).lines();
|
||||
|
||||
let mut event_type = String::new();
|
||||
let mut event_data = String::new();
|
||||
let mut event_id = String::new();
|
||||
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
if line.is_empty() {
|
||||
// Empty line indicates end of event
|
||||
if !event_data.is_empty() {
|
||||
// Parse the streamed data as JSON-RPC message
|
||||
match serde_json::from_str::<JsonRpcMessage>(&event_data) {
|
||||
Ok(message) => {
|
||||
debug!("Received streaming HTTP response message: {:?}", message);
|
||||
let _ = self.sender.send(message).await;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to parse streaming HTTP response message: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Reset for next event
|
||||
event_type.clear();
|
||||
event_data.clear();
|
||||
event_id.clear();
|
||||
} else if let Some(field_data) = line.strip_prefix("data: ") {
|
||||
if !event_data.is_empty() {
|
||||
event_data.push('\n');
|
||||
}
|
||||
event_data.push_str(field_data);
|
||||
} else if let Some(field_data) = line.strip_prefix("event: ") {
|
||||
event_type = field_data.to_string();
|
||||
} else if let Some(field_data) = line.strip_prefix("id: ") {
|
||||
event_id = field_data.to_string();
|
||||
}
|
||||
// Ignore other fields (retry, etc.) - we only care about data
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StreamableHttpTransportHandle {
|
||||
sender: mpsc::Sender<String>,
|
||||
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
|
||||
session_id: Arc<RwLock<Option<String>>>,
|
||||
mcp_endpoint: String,
|
||||
http_client: HttpClient,
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransportHandle for StreamableHttpTransportHandle {
|
||||
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
|
||||
serialize_and_send(&self.sender, message).await
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
|
||||
let mut receiver = self.receiver.lock().await;
|
||||
receiver.recv().await.ok_or(Error::ChannelClosed)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamableHttpTransportHandle {
|
||||
/// Manually terminate the session by sending HTTP DELETE
|
||||
pub async fn terminate_session(&self) -> Result<(), Error> {
|
||||
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
||||
let mut request = self
|
||||
.http_client
|
||||
.delete(&self.mcp_endpoint)
|
||||
.header("Mcp-Session-Id", session_id);
|
||||
|
||||
// Add custom headers
|
||||
for (key, value) in &self.headers {
|
||||
request = request.header(key, value);
|
||||
}
|
||||
|
||||
match request.send().await {
|
||||
Ok(response) => {
|
||||
if response.status().as_u16() == 405 {
|
||||
// Method not allowed - server doesn't support session termination
|
||||
debug!("Server doesn't support session termination");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to terminate session: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a GET request to establish a streaming connection for server-initiated messages
|
||||
pub async fn listen_for_server_messages(&self) -> Result<(), Error> {
|
||||
let mut request = self
|
||||
.http_client
|
||||
.get(&self.mcp_endpoint)
|
||||
.header("Accept", "text/event-stream");
|
||||
|
||||
// Add session ID header if we have one
|
||||
if let Some(session_id) = self.session_id.read().await.as_ref() {
|
||||
request = request.header("Mcp-Session-Id", session_id);
|
||||
}
|
||||
|
||||
// Add custom headers
|
||||
for (key, value) in &self.headers {
|
||||
request = request.header(key, value);
|
||||
}
|
||||
|
||||
let response = request.send().await.map_err(|e| {
|
||||
Error::StreamableHttpError(format!("Failed to start GET streaming connection: {}", e))
|
||||
})?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
if response.status().as_u16() == 405 {
|
||||
// Method not allowed - server doesn't support GET streaming connections
|
||||
debug!("Server doesn't support GET streaming connections");
|
||||
return Ok(());
|
||||
}
|
||||
return Err(Error::HttpError {
|
||||
status: response.status().as_u16(),
|
||||
message: "Failed to establish GET streaming connection".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Handle the streaming connection in a separate task
|
||||
let receiver = self.receiver.clone();
|
||||
let url = response.url().clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let client = match eventsource_client::ClientBuilder::for_url(url.as_str()) {
|
||||
Ok(builder) => builder.build(),
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Failed to create streaming client for GET connection: {}",
|
||||
e
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut stream = client.stream();
|
||||
while let Ok(Some(event)) = stream.try_next().await {
|
||||
match event {
|
||||
SSE::Event(e) if e.event_type == "message" || e.event_type.is_empty() => {
|
||||
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
|
||||
Ok(message) => {
|
||||
debug!("Received GET streaming message: {:?}", message);
|
||||
let receiver_guard = receiver.lock().await;
|
||||
// We can't send through the receiver since it's for outbound messages
|
||||
// This would need a different channel for server-initiated messages
|
||||
drop(receiver_guard);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to parse GET streaming message: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StreamableHttpTransport {
|
||||
mcp_endpoint: String,
|
||||
env: HashMap<String, String>,
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl StreamableHttpTransport {
|
||||
pub fn new<S: Into<String>>(mcp_endpoint: S, env: HashMap<String, String>) -> Self {
|
||||
Self {
|
||||
mcp_endpoint: mcp_endpoint.into(),
|
||||
env,
|
||||
headers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_headers<S: Into<String>>(
|
||||
mcp_endpoint: S,
|
||||
env: HashMap<String, String>,
|
||||
headers: HashMap<String, String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
mcp_endpoint: mcp_endpoint.into(),
|
||||
env,
|
||||
headers,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate that the URL is a valid MCP endpoint
|
||||
pub fn validate_endpoint(endpoint: &str) -> Result<(), Error> {
|
||||
Url::parse(endpoint)
|
||||
.map_err(|e| Error::StreamableHttpError(format!("Invalid MCP endpoint URL: {}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for StreamableHttpTransport {
|
||||
type Handle = StreamableHttpTransportHandle;
|
||||
|
||||
async fn start(&self) -> Result<Self::Handle, Error> {
|
||||
// Validate the endpoint URL
|
||||
Self::validate_endpoint(&self.mcp_endpoint)?;
|
||||
|
||||
// Create channels for communication
|
||||
let (tx, rx) = mpsc::channel(32);
|
||||
let (otx, orx) = mpsc::channel(32);
|
||||
|
||||
let session_id: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
||||
let session_id_clone = Arc::clone(&session_id);
|
||||
|
||||
// Create and spawn the actor
|
||||
let actor = StreamableHttpActor::new(
|
||||
rx,
|
||||
otx,
|
||||
self.mcp_endpoint.clone(),
|
||||
session_id,
|
||||
self.env.clone(),
|
||||
self.headers.clone(),
|
||||
);
|
||||
|
||||
tokio::spawn(actor.run());
|
||||
|
||||
// Create the handle
|
||||
let handle = StreamableHttpTransportHandle {
|
||||
sender: tx,
|
||||
receiver: Arc::new(Mutex::new(orx)),
|
||||
session_id: session_id_clone,
|
||||
mcp_endpoint: self.mcp_endpoint.clone(),
|
||||
http_client: HttpClient::builder()
|
||||
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
|
||||
.build()
|
||||
.unwrap(),
|
||||
headers: self.headers.clone(),
|
||||
};
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), Error> {
|
||||
// The transport is closed when the actor task completes
|
||||
// No additional cleanup needed
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1315,6 +1315,54 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Streamable HTTP client with a URI endpoint using MCP Streamable HTTP specification",
|
||||
"required": [
|
||||
"name",
|
||||
"uri",
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"bundled": {
|
||||
"type": "boolean",
|
||||
"description": "Whether this extension is bundled with Goose",
|
||||
"nullable": true
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
},
|
||||
"env_keys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"envs": {
|
||||
"$ref": "#/components/schemas/Envs"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The name used to identify this extension"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
"nullable": true,
|
||||
"minimum": 0
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"streamable_http"
|
||||
]
|
||||
},
|
||||
"uri": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Frontend-provided tools that will be called through the frontend",
|
||||
|
||||
@@ -124,6 +124,21 @@ export type ExtensionConfig = {
|
||||
name: string;
|
||||
timeout?: number | null;
|
||||
type: 'builtin';
|
||||
} | {
|
||||
/**
|
||||
* Whether this extension is bundled with Goose
|
||||
*/
|
||||
bundled?: boolean | null;
|
||||
description?: string | null;
|
||||
env_keys?: Array<string>;
|
||||
envs?: Envs;
|
||||
/**
|
||||
* The name used to identify this extension
|
||||
*/
|
||||
name: string;
|
||||
timeout?: number | null;
|
||||
type: 'streamable_http';
|
||||
uri: string;
|
||||
} | {
|
||||
/**
|
||||
* Whether this extension is bundled with Goose
|
||||
|
||||
@@ -36,7 +36,7 @@ interface CreateScheduleModalProps {
|
||||
// Interface for clean extension in YAML
|
||||
interface CleanExtension {
|
||||
name: string;
|
||||
type: 'stdio' | 'sse' | 'builtin' | 'frontend';
|
||||
type: 'stdio' | 'sse' | 'builtin' | 'frontend' | 'streamable_http';
|
||||
cmd?: string;
|
||||
args?: string[];
|
||||
uri?: string;
|
||||
@@ -160,6 +160,8 @@ function recipeToYaml(recipe: Recipe, executionMode: ExecutionMode): string {
|
||||
|
||||
if (ext.type === 'sse' && extAny.uri) {
|
||||
cleanExt.uri = extAny.uri as string;
|
||||
} else if (ext.type === 'streamable_http' && extAny.uri) {
|
||||
cleanExt.uri = extAny.uri as string;
|
||||
} else if (ext.type === 'stdio') {
|
||||
if (extAny.cmd) {
|
||||
cleanExt.cmd = extAny.cmd as string;
|
||||
@@ -195,7 +197,8 @@ function recipeToYaml(recipe: Recipe, executionMode: ExecutionMode): string {
|
||||
cleanExt.type = 'stdio';
|
||||
cleanExt.cmd = extAny.command as string;
|
||||
} else if (extAny.uri) {
|
||||
cleanExt.type = 'sse';
|
||||
// Default to streamable_http for URI-based extensions for forward compatibility
|
||||
cleanExt.type = 'streamable_http';
|
||||
cleanExt.uri = extAny.uri as string;
|
||||
} else if (extAny.tools) {
|
||||
cleanExt.type = 'frontend';
|
||||
|
||||
@@ -72,6 +72,26 @@ function getSseConfig(remoteUrl: string, name: string, description: string, time
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build an extension config for Streamable HTTP from the deeplink URL
|
||||
*/
|
||||
function getStreamableHttpConfig(
|
||||
remoteUrl: string,
|
||||
name: string,
|
||||
description: string,
|
||||
timeout: number
|
||||
) {
|
||||
const config: ExtensionConfig = {
|
||||
name,
|
||||
type: 'streamable_http',
|
||||
uri: remoteUrl,
|
||||
description,
|
||||
timeout: timeout,
|
||||
};
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles adding an extension from a deeplink URL
|
||||
*/
|
||||
@@ -120,9 +140,12 @@ export async function addExtensionFromDeepLink(
|
||||
|
||||
const cmd = parsedUrl.searchParams.get('cmd');
|
||||
const remoteUrl = parsedUrl.searchParams.get('url');
|
||||
const transportType = parsedUrl.searchParams.get('transport') || 'sse'; // Default to SSE for backward compatibility
|
||||
|
||||
const config = remoteUrl
|
||||
? getSseConfig(remoteUrl, name, description || '', timeout)
|
||||
? transportType === 'streamable_http'
|
||||
? getStreamableHttpConfig(remoteUrl, name, description || '', timeout)
|
||||
: getSseConfig(remoteUrl, name, description || '', timeout)
|
||||
: getStdioConfig(cmd!, parsedUrl, name, description || '', timeout);
|
||||
|
||||
// Check if extension requires env vars and go to settings if so
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Input } from '../../../ui/input';
|
||||
|
||||
interface ExtensionConfigFieldsProps {
|
||||
type: 'stdio' | 'sse' | 'builtin';
|
||||
type: 'stdio' | 'sse' | 'streamable_http' | 'builtin';
|
||||
full_cmd: string;
|
||||
endpoint: string;
|
||||
onChange: (key: string, value: string) => void;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Select } from '../../../ui/Select';
|
||||
|
||||
interface ExtensionInfoFieldsProps {
|
||||
name: string;
|
||||
type: 'stdio' | 'sse' | 'builtin';
|
||||
type: 'stdio' | 'sse' | 'streamable_http' | 'builtin';
|
||||
description: string;
|
||||
onChange: (key: string, value: string) => void;
|
||||
submitAttempted: boolean;
|
||||
@@ -43,7 +43,17 @@ export default function ExtensionInfoFields({
|
||||
<div className="w-[200px]">
|
||||
<label className="text-sm font-medium mb-2 block text-textStandard">Type</label>
|
||||
<Select
|
||||
value={{ value: type, label: type.toUpperCase() }}
|
||||
value={{
|
||||
value: type,
|
||||
label:
|
||||
type === 'stdio'
|
||||
? 'STDIO'
|
||||
: type === 'sse'
|
||||
? 'SSE'
|
||||
: type === 'streamable_http'
|
||||
? 'HTTP'
|
||||
: type.toUpperCase(),
|
||||
}}
|
||||
onChange={(newValue: unknown) => {
|
||||
const option = newValue as { value: string; label: string } | null;
|
||||
if (option) {
|
||||
@@ -53,6 +63,7 @@ export default function ExtensionInfoFields({
|
||||
options={[
|
||||
{ value: 'stdio', label: 'Standard IO (STDIO)' },
|
||||
{ value: 'sse', label: 'Server-Sent Events (SSE)' },
|
||||
{ value: 'streamable_http', label: 'Streamable HTTP' },
|
||||
]}
|
||||
isSearchable={false}
|
||||
/>
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Button } from '../../../ui/button';
|
||||
import Modal from '../../../Modal';
|
||||
import { ExtensionFormData } from '../utils';
|
||||
import EnvVarsSection from './EnvVarsSection';
|
||||
import HeadersSection from './HeadersSection';
|
||||
import ExtensionConfigFields from './ExtensionConfigFields';
|
||||
import { PlusIcon, Edit, Trash2, AlertTriangle } from 'lucide-react';
|
||||
import ExtensionInfoFields from './ExtensionInfoFields';
|
||||
@@ -34,13 +35,18 @@ export default function ExtensionModal({
|
||||
const [submitAttempted, setSubmitAttempted] = useState(false);
|
||||
const [showCloseConfirmation, setShowCloseConfirmation] = useState(false);
|
||||
const [hasPendingEnvVars, setHasPendingEnvVars] = useState(false);
|
||||
const [hasPendingHeaders, setHasPendingHeaders] = useState(false);
|
||||
|
||||
// Function to check if form has been modified
|
||||
const hasFormChanges = (): boolean => {
|
||||
// Check if command/endpoint has changed
|
||||
const commandChanged =
|
||||
(formData.type === 'stdio' && formData.cmd !== initialData.cmd) ||
|
||||
(formData.type === 'sse' && formData.endpoint !== initialData.endpoint);
|
||||
(formData.type === 'sse' && formData.endpoint !== initialData.endpoint) ||
|
||||
(formData.type === 'streamable_http' && formData.endpoint !== initialData.endpoint);
|
||||
|
||||
// Check if headers have changed
|
||||
const headersChanged = formData.headers.some((header) => header.isEdited === true);
|
||||
|
||||
// Check if any environment variables have been modified
|
||||
const envVarsChanged = formData.envVars.some((envVar) => envVar.isEdited === true);
|
||||
@@ -60,10 +66,11 @@ export default function ExtensionModal({
|
||||
);
|
||||
|
||||
// Check if there are pending environment variables being typed
|
||||
const hasPendingInput = hasPendingEnvVars;
|
||||
const hasPendingInput = hasPendingEnvVars || hasPendingHeaders;
|
||||
|
||||
return (
|
||||
commandChanged ||
|
||||
headersChanged ||
|
||||
envVarsChanged ||
|
||||
envVarsAdded ||
|
||||
envVarsRemoved ||
|
||||
@@ -123,6 +130,37 @@ export default function ExtensionModal({
|
||||
});
|
||||
};
|
||||
|
||||
const handleAddHeader = (key: string, value: string) => {
|
||||
setFormData({
|
||||
...formData,
|
||||
headers: [...formData.headers, { key, value, isEdited: true }],
|
||||
});
|
||||
};
|
||||
|
||||
const handleRemoveHeader = (index: number) => {
|
||||
const newHeaders = [...formData.headers];
|
||||
newHeaders.splice(index, 1);
|
||||
setFormData({
|
||||
...formData,
|
||||
headers: newHeaders,
|
||||
});
|
||||
};
|
||||
|
||||
const handleHeaderChange = (index: number, field: 'key' | 'value', value: string) => {
|
||||
const newHeaders = [...formData.headers];
|
||||
newHeaders[index][field] = value;
|
||||
|
||||
// Mark as edited if it's a value change
|
||||
if (field === 'value') {
|
||||
newHeaders[index].isEdited = true;
|
||||
}
|
||||
|
||||
setFormData({
|
||||
...formData,
|
||||
headers: newHeaders,
|
||||
});
|
||||
};
|
||||
|
||||
// Function to store a secret value
|
||||
const storeSecret = async (key: string, value: string) => {
|
||||
try {
|
||||
@@ -159,7 +197,10 @@ export default function ExtensionModal({
|
||||
const isConfigValid = () => {
|
||||
return (
|
||||
(formData.type === 'stdio' && !!formData.cmd && formData.cmd.trim() !== '') ||
|
||||
(formData.type === 'sse' && !!formData.endpoint && formData.endpoint.trim() !== '')
|
||||
(formData.type === 'sse' && !!formData.endpoint && formData.endpoint.trim() !== '') ||
|
||||
(formData.type === 'streamable_http' &&
|
||||
!!formData.endpoint &&
|
||||
formData.endpoint.trim() !== '')
|
||||
);
|
||||
};
|
||||
|
||||
@@ -169,6 +210,12 @@ export default function ExtensionModal({
|
||||
);
|
||||
};
|
||||
|
||||
const isHeadersValid = () => {
|
||||
return formData.headers.every(
|
||||
({ key, value }) => (key === '' && value === '') || (key !== '' && value !== '')
|
||||
);
|
||||
};
|
||||
|
||||
const isTimeoutValid = () => {
|
||||
// Check if timeout is not undefined, null, or empty string
|
||||
if (formData.timeout === undefined || formData.timeout === null) {
|
||||
@@ -185,7 +232,9 @@ export default function ExtensionModal({
|
||||
|
||||
// Form validation
|
||||
const isFormValid = () => {
|
||||
return isNameValid() && isConfigValid() && isEnvVarsValid() && isTimeoutValid();
|
||||
return (
|
||||
isNameValid() && isConfigValid() && isEnvVarsValid() && isHeadersValid() && isTimeoutValid()
|
||||
);
|
||||
};
|
||||
|
||||
// Handle submit with validation and secret storage
|
||||
@@ -344,6 +393,25 @@ export default function ExtensionModal({
|
||||
onPendingInputChange={setHasPendingEnvVars}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Request Headers - Only for streamable_http */}
|
||||
{formData.type === 'streamable_http' && (
|
||||
<>
|
||||
{/* Divider */}
|
||||
<hr className="border-t border-borderSubtle mb-4" />
|
||||
|
||||
<div className="mb-6">
|
||||
<HeadersSection
|
||||
headers={formData.headers}
|
||||
onAdd={handleAddHeader}
|
||||
onRemove={handleRemoveHeader}
|
||||
onChange={handleHeaderChange}
|
||||
submitAttempted={submitAttempted}
|
||||
onPendingInputChange={setHasPendingHeaders}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Modal>
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
import React from 'react';
|
||||
import { Button } from '../../../ui/button';
|
||||
import { Plus, X } from 'lucide-react';
|
||||
import { Input } from '../../../ui/input';
|
||||
import { cn } from '../../../../utils';
|
||||
|
||||
interface HeadersSectionProps {
|
||||
headers: { key: string; value: string; isEdited?: boolean }[];
|
||||
onAdd: (key: string, value: string) => void;
|
||||
onRemove: (index: number) => void;
|
||||
onChange: (index: number, field: 'key' | 'value', value: string) => void;
|
||||
submitAttempted: boolean;
|
||||
onPendingInputChange?: (hasPending: boolean) => void;
|
||||
}
|
||||
|
||||
export default function HeadersSection({
|
||||
headers,
|
||||
onAdd,
|
||||
onRemove,
|
||||
onChange,
|
||||
submitAttempted,
|
||||
onPendingInputChange,
|
||||
}: HeadersSectionProps) {
|
||||
const [newKey, setNewKey] = React.useState('');
|
||||
const [newValue, setNewValue] = React.useState('');
|
||||
const [validationError, setValidationError] = React.useState<string | null>(null);
|
||||
const [invalidFields, setInvalidFields] = React.useState<{ key: boolean; value: boolean }>({
|
||||
key: false,
|
||||
value: false,
|
||||
});
|
||||
|
||||
// Track pending input changes
|
||||
React.useEffect(() => {
|
||||
const hasPendingInput = newKey.trim() !== '' || newValue.trim() !== '';
|
||||
onPendingInputChange?.(hasPendingInput);
|
||||
}, [newKey, newValue, onPendingInputChange]);
|
||||
|
||||
const handleAdd = () => {
|
||||
const keyEmpty = !newKey.trim();
|
||||
const valueEmpty = !newValue.trim();
|
||||
const keyHasSpaces = newKey.includes(' ');
|
||||
|
||||
if (keyEmpty || valueEmpty) {
|
||||
setInvalidFields({
|
||||
key: keyEmpty,
|
||||
value: valueEmpty,
|
||||
});
|
||||
setValidationError('Both header name and value must be entered');
|
||||
return;
|
||||
}
|
||||
|
||||
if (keyHasSpaces) {
|
||||
setInvalidFields({
|
||||
key: true,
|
||||
value: false,
|
||||
});
|
||||
setValidationError('Header name cannot contain spaces');
|
||||
return;
|
||||
}
|
||||
|
||||
setValidationError(null);
|
||||
setInvalidFields({ key: false, value: false });
|
||||
onAdd(newKey, newValue);
|
||||
setNewKey('');
|
||||
setNewValue('');
|
||||
};
|
||||
|
||||
const clearValidation = () => {
|
||||
setValidationError(null);
|
||||
setInvalidFields({ key: false, value: false });
|
||||
};
|
||||
|
||||
const isFieldInvalid = (index: number, field: 'key' | 'value') => {
|
||||
if (!submitAttempted) return false;
|
||||
const value = headers[index][field].trim();
|
||||
return value === '';
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="relative mb-2">
|
||||
<label className="text-sm font-medium text-textStandard mb-2 block">Request Headers</label>
|
||||
<p className="text-xs text-textSubtle mb-4">
|
||||
Add custom HTTP headers to include in requests to the MCP server. Click the "+" button to
|
||||
add after filling both fields.
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid grid-cols-[1fr_1fr_auto] gap-2 items-center">
|
||||
{/* Existing headers */}
|
||||
{headers.map((header, index) => (
|
||||
<React.Fragment key={index}>
|
||||
<div className="relative">
|
||||
<Input
|
||||
value={header.key}
|
||||
onChange={(e) => onChange(index, 'key', e.target.value)}
|
||||
placeholder="Header name"
|
||||
className={cn(
|
||||
'w-full text-textStandard border-borderSubtle hover:border-borderStandard',
|
||||
isFieldInvalid(index, 'key') && 'border-red-500 focus:border-red-500'
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="relative">
|
||||
<Input
|
||||
value={header.value}
|
||||
onChange={(e) => onChange(index, 'value', e.target.value)}
|
||||
placeholder="Value"
|
||||
className={cn(
|
||||
'w-full text-textStandard border-borderSubtle hover:border-borderStandard',
|
||||
isFieldInvalid(index, 'value') && 'border-red-500 focus:border-red-500'
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
onClick={() => onRemove(index)}
|
||||
variant="ghost"
|
||||
className="group p-2 h-auto text-iconSubtle hover:bg-transparent"
|
||||
>
|
||||
<X className="h-3 w-3 text-gray-400 group-hover:text-white group-hover:drop-shadow-sm transition-all" />
|
||||
</Button>
|
||||
</React.Fragment>
|
||||
))}
|
||||
|
||||
{/* Empty row with Add button */}
|
||||
<Input
|
||||
value={newKey}
|
||||
onChange={(e) => {
|
||||
setNewKey(e.target.value);
|
||||
clearValidation();
|
||||
}}
|
||||
placeholder="Header name"
|
||||
className={cn(
|
||||
'w-full text-textStandard border-borderSubtle hover:border-borderStandard',
|
||||
invalidFields.key && 'border-red-500 focus:border-red-500'
|
||||
)}
|
||||
/>
|
||||
<Input
|
||||
value={newValue}
|
||||
onChange={(e) => {
|
||||
setNewValue(e.target.value);
|
||||
clearValidation();
|
||||
}}
|
||||
placeholder="Value"
|
||||
className={cn(
|
||||
'w-full text-textStandard border-borderSubtle hover:border-borderStandard',
|
||||
invalidFields.value && 'border-red-500 focus:border-red-500'
|
||||
)}
|
||||
/>
|
||||
<Button
|
||||
onClick={handleAdd}
|
||||
variant="ghost"
|
||||
className="flex items-center justify-start gap-1 px-2 pr-4 text-sm rounded-full text-textStandard bg-bgApp border border-borderSubtle hover:border-borderStandard transition-colors min-w-[60px] h-9 [&>svg]:!size-4"
|
||||
>
|
||||
<Plus /> Add
|
||||
</Button>
|
||||
</div>
|
||||
{validationError && <div className="mt-2 text-red-500 text-sm">{validationError}</div>}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -93,6 +93,14 @@ export function getSubtitle(config: ExtensionConfig): SubtitleParts {
|
||||
return { description, command };
|
||||
}
|
||||
|
||||
if (config.type === 'streamable_http') {
|
||||
const description = config.description
|
||||
? `Streamable HTTP extension: ${config.description}`
|
||||
: 'Streamable HTTP extension';
|
||||
const command = config.uri || null;
|
||||
return { description, command };
|
||||
}
|
||||
|
||||
return {
|
||||
description: 'Unknown type of extension',
|
||||
command: null,
|
||||
|
||||
@@ -21,7 +21,7 @@ import { ExtensionConfig } from '../../../api/types.gen';
|
||||
export interface ExtensionFormData {
|
||||
name: string;
|
||||
description: string;
|
||||
type: 'stdio' | 'sse' | 'builtin';
|
||||
type: 'stdio' | 'sse' | 'streamable_http' | 'builtin';
|
||||
cmd?: string;
|
||||
endpoint?: string;
|
||||
enabled: boolean;
|
||||
@@ -31,6 +31,11 @@ export interface ExtensionFormData {
|
||||
value: string;
|
||||
isEdited?: boolean;
|
||||
}[];
|
||||
headers: {
|
||||
key: string;
|
||||
value: string;
|
||||
isEdited?: boolean;
|
||||
}[];
|
||||
}
|
||||
|
||||
export function getDefaultFormData(): ExtensionFormData {
|
||||
@@ -43,12 +48,14 @@ export function getDefaultFormData(): ExtensionFormData {
|
||||
enabled: true,
|
||||
timeout: 300,
|
||||
envVars: [],
|
||||
headers: [],
|
||||
};
|
||||
}
|
||||
|
||||
export function extensionToFormData(extension: FixedExtensionEntry): ExtensionFormData {
|
||||
// Type guard: Check if 'envs' property exists for this variant
|
||||
const hasEnvs = extension.type === 'sse' || extension.type === 'stdio';
|
||||
const hasEnvs =
|
||||
extension.type === 'sse' || extension.type === 'streamable_http' || extension.type === 'stdio';
|
||||
|
||||
// Handle both envs (legacy) and env_keys (new secrets)
|
||||
let envVars = [];
|
||||
@@ -75,16 +82,32 @@ export function extensionToFormData(extension: FixedExtensionEntry): ExtensionFo
|
||||
);
|
||||
}
|
||||
|
||||
// Handle headers for streamable_http
|
||||
let headers = [];
|
||||
if (extension.type === 'streamable_http' && 'headers' in extension && extension.headers) {
|
||||
headers.push(
|
||||
...Object.entries(extension.headers).map(([key, value]) => ({
|
||||
key,
|
||||
value: value as string,
|
||||
isEdited: false, // Mark as not edited initially
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
name: extension.name || '',
|
||||
description:
|
||||
extension.type === 'stdio' || extension.type === 'sse' ? extension.description || '' : '',
|
||||
extension.type === 'stdio' || extension.type === 'sse' || extension.type === 'streamable_http'
|
||||
? extension.description || ''
|
||||
: '',
|
||||
type: extension.type === 'frontend' ? 'stdio' : extension.type,
|
||||
cmd: extension.type === 'stdio' ? combineCmdAndArgs(extension.cmd, extension.args) : undefined,
|
||||
endpoint: extension.type === 'sse' ? extension.uri : undefined,
|
||||
endpoint:
|
||||
extension.type === 'sse' || extension.type === 'streamable_http' ? extension.uri : undefined,
|
||||
enabled: extension.enabled,
|
||||
timeout: 'timeout' in extension ? (extension.timeout ?? undefined) : undefined,
|
||||
envVars,
|
||||
headers,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -114,6 +137,27 @@ export function createExtensionConfig(formData: ExtensionFormData): ExtensionCon
|
||||
uri: formData.endpoint || '',
|
||||
...(env_keys.length > 0 ? { env_keys } : {}),
|
||||
};
|
||||
} else if (formData.type === 'streamable_http') {
|
||||
// Extract headers
|
||||
const headers = formData.headers
|
||||
.filter(({ key, value }) => key.length > 0 && value.length > 0)
|
||||
.reduce(
|
||||
(acc, header) => {
|
||||
acc[header.key] = header.value;
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, string>
|
||||
);
|
||||
|
||||
return {
|
||||
type: 'streamable_http',
|
||||
name: formData.name,
|
||||
description: formData.description,
|
||||
timeout: formData.timeout,
|
||||
uri: formData.endpoint || '',
|
||||
...(env_keys.length > 0 ? { env_keys } : {}),
|
||||
...(Object.keys(headers).length > 0 ? { headers } : {}),
|
||||
};
|
||||
} else {
|
||||
// For other types
|
||||
return {
|
||||
|
||||
@@ -17,6 +17,14 @@ export type ExtensionConfig =
|
||||
env_keys?: string[];
|
||||
timeout?: number;
|
||||
}
|
||||
| {
|
||||
type: 'streamable_http';
|
||||
name: string;
|
||||
uri: string;
|
||||
env_keys?: string[];
|
||||
headers?: Record<string, string>;
|
||||
timeout?: number;
|
||||
}
|
||||
| {
|
||||
type: 'stdio';
|
||||
name: string;
|
||||
@@ -73,6 +81,10 @@ export async function addExtension(
|
||||
name: sanitizeName(extension.name),
|
||||
uri: extension.uri,
|
||||
}),
|
||||
...(extension.type === 'streamable_http' && {
|
||||
name: sanitizeName(extension.name),
|
||||
uri: extension.uri,
|
||||
}),
|
||||
...(extension.type === 'builtin' && {
|
||||
name: sanitizeName(extension.name),
|
||||
}),
|
||||
|
||||
Reference in New Issue
Block a user