StreamableHttp - GET request standalone SSE (#561)

This commit is contained in:
ihrpr
2025-05-02 13:52:27 +01:00
committed by GitHub
parent 72b66a58b1
commit 46523afe30
3 changed files with 186 additions and 17 deletions

View File

@@ -11,6 +11,7 @@ from mcp.server.streamableHttp import (
MCP_SESSION_ID_HEADER, MCP_SESSION_ID_HEADER,
StreamableHTTPServerTransport, StreamableHTTPServerTransport,
) )
from pydantic import AnyUrl
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
@@ -92,6 +93,9 @@ def main(
if i < count - 1: # Don't wait after the last notification if i < count - 1: # Don't wait after the last notification
await anyio.sleep(interval) await anyio.sleep(interval)
# This will send a resource notificaiton though standalone SSE
# established by GET request
await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource"))
return [ return [
types.TextContent( types.TextContent(
type="text", type="text",

View File

@@ -50,6 +50,9 @@ LAST_EVENT_ID_HEADER = "last-event-id"
CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_SSE = "text/event-stream" CONTENT_TYPE_SSE = "text/event-stream"
# Special key for the standalone GET stream
GET_STREAM_KEY = "_GET_stream"
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) # Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors # Pattern ensures entire string contains only valid characters by using ^ and $ anchors
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
@@ -443,10 +446,19 @@ class StreamableHTTPServerTransport:
return return
async def _handle_get_request(self, request: Request, send: Send) -> None: async def _handle_get_request(self, request: Request, send: Send) -> None:
"""Handle GET requests for SSE stream establishment.""" """
# Validate session ID if server has one Handle GET request to establish SSE.
if not await self._validate_session(request, send):
return This allows the server to communicate to the client without the client
first sending data via HTTP POST. The server can send JSON-RPC requests
and notifications on this stream.
"""
writer = self._read_stream_writer
if writer is None:
raise ValueError(
"No read stream writer available. Ensure connect() is called first."
)
# Validate Accept header - must include text/event-stream # Validate Accept header - must include text/event-stream
_, has_sse = self._check_accept_headers(request) _, has_sse = self._check_accept_headers(request)
@@ -458,13 +470,80 @@ class StreamableHTTPServerTransport:
await response(request.scope, request.receive, send) await response(request.scope, request.receive, send)
return return
# TODO: Implement SSE stream for GET requests if not await self._validate_session(request, send):
# For now, return 405 Method Not Allowed return
headers = {
"Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive",
"Content-Type": CONTENT_TYPE_SSE,
}
if self.mcp_session_id:
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
# Check if we already have an active GET stream
if GET_STREAM_KEY in self._request_streams:
response = self._create_error_response( response = self._create_error_response(
"SSE stream from GET request not implemented yet", "Conflict: Only one SSE stream is allowed per session",
HTTPStatus.METHOD_NOT_ALLOWED, HTTPStatus.CONFLICT,
) )
await response(request.scope, request.receive, send) await response(request.scope, request.receive, send)
return
# Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
dict[str, Any]
](0)
async def standalone_sse_writer():
try:
# Create a standalone message stream for server-initiated messages
standalone_stream_writer, standalone_stream_reader = (
anyio.create_memory_object_stream[JSONRPCMessage](0)
)
# Register this stream using the special key
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
async with sse_stream_writer, standalone_stream_reader:
# Process messages from the standalone stream
async for received_message in standalone_stream_reader:
# For the standalone stream, we handle:
# - JSONRPCNotification (server sends notifications to client)
# - JSONRPCRequest (server sends requests to client)
# We should NOT receive JSONRPCResponse
# Send the message via SSE
event_data = {
"event": "message",
"data": received_message.model_dump_json(
by_alias=True, exclude_none=True
),
}
await sse_stream_writer.send(event_data)
except Exception as e:
logger.exception(f"Error in standalone SSE writer: {e}")
finally:
logger.debug("Closing standalone SSE writer")
# Remove the stream from request_streams
self._request_streams.pop(GET_STREAM_KEY, None)
# Create and start EventSourceResponse
response = EventSourceResponse(
content=sse_stream_reader,
data_sender_callable=standalone_sse_writer,
headers=headers,
)
try:
# This will send headers immediately and establish the SSE connection
await response(request.scope, request.receive, send)
except Exception as e:
logger.exception(f"Error in standalone SSE response: {e}")
# Clean up the request stream
self._request_streams.pop(GET_STREAM_KEY, None)
async def _handle_delete_request(self, request: Request, send: Send) -> None: async def _handle_delete_request(self, request: Request, send: Send) -> None:
"""Handle DELETE requests for explicit session termination.""" """Handle DELETE requests for explicit session termination."""
@@ -611,13 +690,10 @@ class StreamableHTTPServerTransport:
else: else:
target_request_id = str(message.root.id) target_request_id = str(message.root.id)
# Send to the specific request stream if available request_stream_id = target_request_id or GET_STREAM_KEY
if ( if request_stream_id in self._request_streams:
target_request_id
and target_request_id in self._request_streams
):
try: try:
await self._request_streams[target_request_id].send( await self._request_streams[request_stream_id].send(
message message
) )
except ( except (
@@ -625,7 +701,7 @@ class StreamableHTTPServerTransport:
anyio.ClosedResourceError, anyio.ClosedResourceError,
): ):
# Stream might be closed, remove from registry # Stream might be closed, remove from registry
self._request_streams.pop(target_request_id, None) self._request_streams.pop(request_stream_id, None)
except Exception as e: except Exception as e:
logger.exception(f"Error in message router: {e}") logger.exception(f"Error in message router: {e}")

View File

@@ -541,3 +541,92 @@ def test_json_response(json_response_server, json_server_url):
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.headers.get("Content-Type") == "application/json" assert response.headers.get("Content-Type") == "application/json"
def test_get_sse_stream(basic_server, basic_server_url):
"""Test establishing an SSE stream via GET request."""
# First, we need to initialize a session
mcp_url = f"{basic_server_url}/mcp"
init_response = requests.post(
mcp_url,
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert init_response.status_code == 200
# Get the session ID
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
assert session_id is not None
# Now attempt to establish an SSE stream via GET
get_response = requests.get(
mcp_url,
headers={
"Accept": "text/event-stream",
MCP_SESSION_ID_HEADER: session_id,
},
stream=True,
)
# Verify we got a successful response with the right content type
assert get_response.status_code == 200
assert get_response.headers.get("Content-Type") == "text/event-stream"
# Test that a second GET request gets rejected (only one stream allowed)
second_get = requests.get(
mcp_url,
headers={
"Accept": "text/event-stream",
MCP_SESSION_ID_HEADER: session_id,
},
stream=True,
)
# Should get CONFLICT (409) since there's already a stream
# Note: This might fail if the first stream fully closed before this runs,
# but generally it should work in the test environment where it runs quickly
assert second_get.status_code == 409
def test_get_validation(basic_server, basic_server_url):
"""Test validation for GET requests."""
# First, we need to initialize a session
mcp_url = f"{basic_server_url}/mcp"
init_response = requests.post(
mcp_url,
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert init_response.status_code == 200
# Get the session ID
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
assert session_id is not None
# Test without Accept header
response = requests.get(
mcp_url,
headers={
MCP_SESSION_ID_HEADER: session_id,
},
stream=True,
)
assert response.status_code == 406
assert "Not Acceptable" in response.text
# Test with wrong Accept header
response = requests.get(
mcp_url,
headers={
"Accept": "application/json",
MCP_SESSION_ID_HEADER: session_id,
},
)
assert response.status_code == 406
assert "Not Acceptable" in response.text