mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
StreamableHttp - GET request standalone SSE (#561)
This commit is contained in:
@@ -11,6 +11,7 @@ from mcp.server.streamableHttp import (
|
|||||||
MCP_SESSION_ID_HEADER,
|
MCP_SESSION_ID_HEADER,
|
||||||
StreamableHTTPServerTransport,
|
StreamableHTTPServerTransport,
|
||||||
)
|
)
|
||||||
|
from pydantic import AnyUrl
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
@@ -92,6 +93,9 @@ def main(
|
|||||||
if i < count - 1: # Don't wait after the last notification
|
if i < count - 1: # Don't wait after the last notification
|
||||||
await anyio.sleep(interval)
|
await anyio.sleep(interval)
|
||||||
|
|
||||||
|
# This will send a resource notificaiton though standalone SSE
|
||||||
|
# established by GET request
|
||||||
|
await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource"))
|
||||||
return [
|
return [
|
||||||
types.TextContent(
|
types.TextContent(
|
||||||
type="text",
|
type="text",
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ LAST_EVENT_ID_HEADER = "last-event-id"
|
|||||||
CONTENT_TYPE_JSON = "application/json"
|
CONTENT_TYPE_JSON = "application/json"
|
||||||
CONTENT_TYPE_SSE = "text/event-stream"
|
CONTENT_TYPE_SSE = "text/event-stream"
|
||||||
|
|
||||||
|
# Special key for the standalone GET stream
|
||||||
|
GET_STREAM_KEY = "_GET_stream"
|
||||||
|
|
||||||
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
|
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
|
||||||
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
|
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
|
||||||
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
|
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
|
||||||
@@ -443,10 +446,19 @@ class StreamableHTTPServerTransport:
|
|||||||
return
|
return
|
||||||
|
|
||||||
async def _handle_get_request(self, request: Request, send: Send) -> None:
|
async def _handle_get_request(self, request: Request, send: Send) -> None:
|
||||||
"""Handle GET requests for SSE stream establishment."""
|
"""
|
||||||
# Validate session ID if server has one
|
Handle GET request to establish SSE.
|
||||||
if not await self._validate_session(request, send):
|
|
||||||
return
|
This allows the server to communicate to the client without the client
|
||||||
|
first sending data via HTTP POST. The server can send JSON-RPC requests
|
||||||
|
and notifications on this stream.
|
||||||
|
"""
|
||||||
|
writer = self._read_stream_writer
|
||||||
|
if writer is None:
|
||||||
|
raise ValueError(
|
||||||
|
"No read stream writer available. Ensure connect() is called first."
|
||||||
|
)
|
||||||
|
|
||||||
# Validate Accept header - must include text/event-stream
|
# Validate Accept header - must include text/event-stream
|
||||||
_, has_sse = self._check_accept_headers(request)
|
_, has_sse = self._check_accept_headers(request)
|
||||||
|
|
||||||
@@ -458,13 +470,80 @@ class StreamableHTTPServerTransport:
|
|||||||
await response(request.scope, request.receive, send)
|
await response(request.scope, request.receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO: Implement SSE stream for GET requests
|
if not await self._validate_session(request, send):
|
||||||
# For now, return 405 Method Not Allowed
|
return
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Cache-Control": "no-cache, no-transform",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"Content-Type": CONTENT_TYPE_SSE,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.mcp_session_id:
|
||||||
|
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
|
||||||
|
|
||||||
|
# Check if we already have an active GET stream
|
||||||
|
if GET_STREAM_KEY in self._request_streams:
|
||||||
response = self._create_error_response(
|
response = self._create_error_response(
|
||||||
"SSE stream from GET request not implemented yet",
|
"Conflict: Only one SSE stream is allowed per session",
|
||||||
HTTPStatus.METHOD_NOT_ALLOWED,
|
HTTPStatus.CONFLICT,
|
||||||
)
|
)
|
||||||
await response(request.scope, request.receive, send)
|
await response(request.scope, request.receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create SSE stream
|
||||||
|
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||||
|
dict[str, Any]
|
||||||
|
](0)
|
||||||
|
|
||||||
|
async def standalone_sse_writer():
|
||||||
|
try:
|
||||||
|
# Create a standalone message stream for server-initiated messages
|
||||||
|
standalone_stream_writer, standalone_stream_reader = (
|
||||||
|
anyio.create_memory_object_stream[JSONRPCMessage](0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register this stream using the special key
|
||||||
|
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
|
||||||
|
|
||||||
|
async with sse_stream_writer, standalone_stream_reader:
|
||||||
|
# Process messages from the standalone stream
|
||||||
|
async for received_message in standalone_stream_reader:
|
||||||
|
# For the standalone stream, we handle:
|
||||||
|
# - JSONRPCNotification (server sends notifications to client)
|
||||||
|
# - JSONRPCRequest (server sends requests to client)
|
||||||
|
# We should NOT receive JSONRPCResponse
|
||||||
|
|
||||||
|
# Send the message via SSE
|
||||||
|
event_data = {
|
||||||
|
"event": "message",
|
||||||
|
"data": received_message.model_dump_json(
|
||||||
|
by_alias=True, exclude_none=True
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
await sse_stream_writer.send(event_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error in standalone SSE writer: {e}")
|
||||||
|
finally:
|
||||||
|
logger.debug("Closing standalone SSE writer")
|
||||||
|
# Remove the stream from request_streams
|
||||||
|
self._request_streams.pop(GET_STREAM_KEY, None)
|
||||||
|
|
||||||
|
# Create and start EventSourceResponse
|
||||||
|
response = EventSourceResponse(
|
||||||
|
content=sse_stream_reader,
|
||||||
|
data_sender_callable=standalone_sse_writer,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# This will send headers immediately and establish the SSE connection
|
||||||
|
await response(request.scope, request.receive, send)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error in standalone SSE response: {e}")
|
||||||
|
# Clean up the request stream
|
||||||
|
self._request_streams.pop(GET_STREAM_KEY, None)
|
||||||
|
|
||||||
async def _handle_delete_request(self, request: Request, send: Send) -> None:
|
async def _handle_delete_request(self, request: Request, send: Send) -> None:
|
||||||
"""Handle DELETE requests for explicit session termination."""
|
"""Handle DELETE requests for explicit session termination."""
|
||||||
@@ -611,13 +690,10 @@ class StreamableHTTPServerTransport:
|
|||||||
else:
|
else:
|
||||||
target_request_id = str(message.root.id)
|
target_request_id = str(message.root.id)
|
||||||
|
|
||||||
# Send to the specific request stream if available
|
request_stream_id = target_request_id or GET_STREAM_KEY
|
||||||
if (
|
if request_stream_id in self._request_streams:
|
||||||
target_request_id
|
|
||||||
and target_request_id in self._request_streams
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
await self._request_streams[target_request_id].send(
|
await self._request_streams[request_stream_id].send(
|
||||||
message
|
message
|
||||||
)
|
)
|
||||||
except (
|
except (
|
||||||
@@ -625,7 +701,7 @@ class StreamableHTTPServerTransport:
|
|||||||
anyio.ClosedResourceError,
|
anyio.ClosedResourceError,
|
||||||
):
|
):
|
||||||
# Stream might be closed, remove from registry
|
# Stream might be closed, remove from registry
|
||||||
self._request_streams.pop(target_request_id, None)
|
self._request_streams.pop(request_stream_id, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error in message router: {e}")
|
logger.exception(f"Error in message router: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -541,3 +541,92 @@ def test_json_response(json_response_server, json_server_url):
|
|||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers.get("Content-Type") == "application/json"
|
assert response.headers.get("Content-Type") == "application/json"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_sse_stream(basic_server, basic_server_url):
|
||||||
|
"""Test establishing an SSE stream via GET request."""
|
||||||
|
# First, we need to initialize a session
|
||||||
|
mcp_url = f"{basic_server_url}/mcp"
|
||||||
|
init_response = requests.post(
|
||||||
|
mcp_url,
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json=INIT_REQUEST,
|
||||||
|
)
|
||||||
|
assert init_response.status_code == 200
|
||||||
|
|
||||||
|
# Get the session ID
|
||||||
|
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||||
|
assert session_id is not None
|
||||||
|
|
||||||
|
# Now attempt to establish an SSE stream via GET
|
||||||
|
get_response = requests.get(
|
||||||
|
mcp_url,
|
||||||
|
headers={
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify we got a successful response with the right content type
|
||||||
|
assert get_response.status_code == 200
|
||||||
|
assert get_response.headers.get("Content-Type") == "text/event-stream"
|
||||||
|
|
||||||
|
# Test that a second GET request gets rejected (only one stream allowed)
|
||||||
|
second_get = requests.get(
|
||||||
|
mcp_url,
|
||||||
|
headers={
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should get CONFLICT (409) since there's already a stream
|
||||||
|
# Note: This might fail if the first stream fully closed before this runs,
|
||||||
|
# but generally it should work in the test environment where it runs quickly
|
||||||
|
assert second_get.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation(basic_server, basic_server_url):
|
||||||
|
"""Test validation for GET requests."""
|
||||||
|
# First, we need to initialize a session
|
||||||
|
mcp_url = f"{basic_server_url}/mcp"
|
||||||
|
init_response = requests.post(
|
||||||
|
mcp_url,
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json=INIT_REQUEST,
|
||||||
|
)
|
||||||
|
assert init_response.status_code == 200
|
||||||
|
|
||||||
|
# Get the session ID
|
||||||
|
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||||
|
assert session_id is not None
|
||||||
|
|
||||||
|
# Test without Accept header
|
||||||
|
response = requests.get(
|
||||||
|
mcp_url,
|
||||||
|
headers={
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
assert response.status_code == 406
|
||||||
|
assert "Not Acceptable" in response.text
|
||||||
|
|
||||||
|
# Test with wrong Accept header
|
||||||
|
response = requests.get(
|
||||||
|
mcp_url,
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json",
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 406
|
||||||
|
assert "Not Acceptable" in response.text
|
||||||
|
|||||||
Reference in New Issue
Block a user