mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
StreamableHttp - GET request standalone SSE (#561)
This commit is contained in:
@@ -11,6 +11,7 @@ from mcp.server.streamableHttp import (
|
||||
MCP_SESSION_ID_HEADER,
|
||||
StreamableHTTPServerTransport,
|
||||
)
|
||||
from pydantic import AnyUrl
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
@@ -92,6 +93,9 @@ def main(
|
||||
if i < count - 1: # Don't wait after the last notification
|
||||
await anyio.sleep(interval)
|
||||
|
||||
# This will send a resource notificaiton though standalone SSE
|
||||
# established by GET request
|
||||
await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource"))
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
|
||||
@@ -50,6 +50,9 @@ LAST_EVENT_ID_HEADER = "last-event-id"
|
||||
CONTENT_TYPE_JSON = "application/json"
|
||||
CONTENT_TYPE_SSE = "text/event-stream"
|
||||
|
||||
# Special key for the standalone GET stream
|
||||
GET_STREAM_KEY = "_GET_stream"
|
||||
|
||||
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
|
||||
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
|
||||
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
|
||||
@@ -443,10 +446,19 @@ class StreamableHTTPServerTransport:
|
||||
return
|
||||
|
||||
async def _handle_get_request(self, request: Request, send: Send) -> None:
|
||||
"""Handle GET requests for SSE stream establishment."""
|
||||
# Validate session ID if server has one
|
||||
if not await self._validate_session(request, send):
|
||||
return
|
||||
"""
|
||||
Handle GET request to establish SSE.
|
||||
|
||||
This allows the server to communicate to the client without the client
|
||||
first sending data via HTTP POST. The server can send JSON-RPC requests
|
||||
and notifications on this stream.
|
||||
"""
|
||||
writer = self._read_stream_writer
|
||||
if writer is None:
|
||||
raise ValueError(
|
||||
"No read stream writer available. Ensure connect() is called first."
|
||||
)
|
||||
|
||||
# Validate Accept header - must include text/event-stream
|
||||
_, has_sse = self._check_accept_headers(request)
|
||||
|
||||
@@ -458,13 +470,80 @@ class StreamableHTTPServerTransport:
|
||||
await response(request.scope, request.receive, send)
|
||||
return
|
||||
|
||||
# TODO: Implement SSE stream for GET requests
|
||||
# For now, return 405 Method Not Allowed
|
||||
if not await self._validate_session(request, send):
|
||||
return
|
||||
|
||||
headers = {
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": CONTENT_TYPE_SSE,
|
||||
}
|
||||
|
||||
if self.mcp_session_id:
|
||||
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
|
||||
|
||||
# Check if we already have an active GET stream
|
||||
if GET_STREAM_KEY in self._request_streams:
|
||||
response = self._create_error_response(
|
||||
"SSE stream from GET request not implemented yet",
|
||||
HTTPStatus.METHOD_NOT_ALLOWED,
|
||||
"Conflict: Only one SSE stream is allowed per session",
|
||||
HTTPStatus.CONFLICT,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
return
|
||||
|
||||
# Create SSE stream
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
dict[str, Any]
|
||||
](0)
|
||||
|
||||
async def standalone_sse_writer():
|
||||
try:
|
||||
# Create a standalone message stream for server-initiated messages
|
||||
standalone_stream_writer, standalone_stream_reader = (
|
||||
anyio.create_memory_object_stream[JSONRPCMessage](0)
|
||||
)
|
||||
|
||||
# Register this stream using the special key
|
||||
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
|
||||
|
||||
async with sse_stream_writer, standalone_stream_reader:
|
||||
# Process messages from the standalone stream
|
||||
async for received_message in standalone_stream_reader:
|
||||
# For the standalone stream, we handle:
|
||||
# - JSONRPCNotification (server sends notifications to client)
|
||||
# - JSONRPCRequest (server sends requests to client)
|
||||
# We should NOT receive JSONRPCResponse
|
||||
|
||||
# Send the message via SSE
|
||||
event_data = {
|
||||
"event": "message",
|
||||
"data": received_message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
await sse_stream_writer.send(event_data)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in standalone SSE writer: {e}")
|
||||
finally:
|
||||
logger.debug("Closing standalone SSE writer")
|
||||
# Remove the stream from request_streams
|
||||
self._request_streams.pop(GET_STREAM_KEY, None)
|
||||
|
||||
# Create and start EventSourceResponse
|
||||
response = EventSourceResponse(
|
||||
content=sse_stream_reader,
|
||||
data_sender_callable=standalone_sse_writer,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
try:
|
||||
# This will send headers immediately and establish the SSE connection
|
||||
await response(request.scope, request.receive, send)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in standalone SSE response: {e}")
|
||||
# Clean up the request stream
|
||||
self._request_streams.pop(GET_STREAM_KEY, None)
|
||||
|
||||
async def _handle_delete_request(self, request: Request, send: Send) -> None:
|
||||
"""Handle DELETE requests for explicit session termination."""
|
||||
@@ -611,13 +690,10 @@ class StreamableHTTPServerTransport:
|
||||
else:
|
||||
target_request_id = str(message.root.id)
|
||||
|
||||
# Send to the specific request stream if available
|
||||
if (
|
||||
target_request_id
|
||||
and target_request_id in self._request_streams
|
||||
):
|
||||
request_stream_id = target_request_id or GET_STREAM_KEY
|
||||
if request_stream_id in self._request_streams:
|
||||
try:
|
||||
await self._request_streams[target_request_id].send(
|
||||
await self._request_streams[request_stream_id].send(
|
||||
message
|
||||
)
|
||||
except (
|
||||
@@ -625,7 +701,7 @@ class StreamableHTTPServerTransport:
|
||||
anyio.ClosedResourceError,
|
||||
):
|
||||
# Stream might be closed, remove from registry
|
||||
self._request_streams.pop(target_request_id, None)
|
||||
self._request_streams.pop(request_stream_id, None)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in message router: {e}")
|
||||
|
||||
|
||||
@@ -541,3 +541,92 @@ def test_json_response(json_response_server, json_server_url):
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers.get("Content-Type") == "application/json"
|
||||
|
||||
|
||||
def test_get_sse_stream(basic_server, basic_server_url):
|
||||
"""Test establishing an SSE stream via GET request."""
|
||||
# First, we need to initialize a session
|
||||
mcp_url = f"{basic_server_url}/mcp"
|
||||
init_response = requests.post(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=INIT_REQUEST,
|
||||
)
|
||||
assert init_response.status_code == 200
|
||||
|
||||
# Get the session ID
|
||||
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
assert session_id is not None
|
||||
|
||||
# Now attempt to establish an SSE stream via GET
|
||||
get_response = requests.get(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "text/event-stream",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Verify we got a successful response with the right content type
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.headers.get("Content-Type") == "text/event-stream"
|
||||
|
||||
# Test that a second GET request gets rejected (only one stream allowed)
|
||||
second_get = requests.get(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "text/event-stream",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Should get CONFLICT (409) since there's already a stream
|
||||
# Note: This might fail if the first stream fully closed before this runs,
|
||||
# but generally it should work in the test environment where it runs quickly
|
||||
assert second_get.status_code == 409
|
||||
|
||||
|
||||
def test_get_validation(basic_server, basic_server_url):
|
||||
"""Test validation for GET requests."""
|
||||
# First, we need to initialize a session
|
||||
mcp_url = f"{basic_server_url}/mcp"
|
||||
init_response = requests.post(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=INIT_REQUEST,
|
||||
)
|
||||
assert init_response.status_code == 200
|
||||
|
||||
# Get the session ID
|
||||
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
assert session_id is not None
|
||||
|
||||
# Test without Accept header
|
||||
response = requests.get(
|
||||
mcp_url,
|
||||
headers={
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
assert response.status_code == 406
|
||||
assert "Not Acceptable" in response.text
|
||||
|
||||
# Test with wrong Accept header
|
||||
response = requests.get(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 406
|
||||
assert "Not Acceptable" in response.text
|
||||
|
||||
Reference in New Issue
Block a user