StreamableHttp -- resumability support for servers (#587)

ihrpr
2025-05-02 14:10:40 +01:00
committed by GitHub
parent 9dfc925090
commit 3978c6e1b9
5 changed files with 340 additions and 55 deletions


@@ -10,10 +10,11 @@ responses, with streaming support for long-running operations.
import json
import logging
import re
from collections.abc import AsyncGenerator
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -57,6 +58,63 @@ GET_STREAM_KEY = "_GET_stream"
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
# Type aliases
StreamId = str
EventId = str
@dataclass
class EventMessage:
"""
A JSONRPCMessage with an optional event ID for stream resumability.
"""
message: JSONRPCMessage
event_id: str | None = None
EventCallback = Callable[[EventMessage], Awaitable[None]]
class EventStore(ABC):
"""
Interface for resumability support via event storage.
"""
@abstractmethod
async def store_event(
self, stream_id: StreamId, message: JSONRPCMessage
) -> EventId:
"""
Stores an event for later retrieval.
Args:
stream_id: ID of the stream the event belongs to
message: The JSON-RPC message to store
Returns:
The generated event ID for the stored event
"""
pass
@abstractmethod
async def replay_events_after(
self,
last_event_id: EventId,
send_callback: EventCallback,
) -> StreamId | None:
"""
Replays events that occurred after the specified event ID.
Args:
last_event_id: The ID of the last event the client received
send_callback: A callback function to send events to the client
Returns:
The stream ID of the replayed events, or None if the event ID is not found
"""
pass
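For illustration, a minimal in-memory implementation of this interface might look like the sketch below (assumptions: events live in a plain dict in insertion order, and eviction and cross-process persistence are ignored):

from uuid import uuid4

class InMemoryEventStore(EventStore):
    """Illustrative in-memory event store; keeps everything in one dict."""

    def __init__(self) -> None:
        # event_id -> (stream_id, message), preserved in insertion order
        self._events: dict[EventId, tuple[StreamId, JSONRPCMessage]] = {}

    async def store_event(
        self, stream_id: StreamId, message: JSONRPCMessage
    ) -> EventId:
        event_id = uuid4().hex
        self._events[event_id] = (stream_id, message)
        return event_id

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> StreamId | None:
        if last_event_id not in self._events:
            return None
        stream_id = self._events[last_event_id][0]
        seen_last = False
        for event_id, (sid, message) in self._events.items():
            # Replay only events on the same stream that came after last_event_id
            if seen_last and sid == stream_id:
                await send_callback(EventMessage(message, event_id))
            if event_id == last_event_id:
                seen_last = True
        return stream_id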
class StreamableHTTPServerTransport:
"""
@@ -76,6 +134,7 @@ class StreamableHTTPServerTransport:
self,
mcp_session_id: str | None,
is_json_response_enabled: bool = False,
event_store: EventStore | None = None,
) -> None:
"""
Initialize a new StreamableHTTP server transport.
@@ -85,6 +144,9 @@ class StreamableHTTPServerTransport:
Must contain only visible ASCII characters (0x21-0x7E).
is_json_response_enabled: If True, return JSON responses for requests
instead of SSE streams. Default is False.
event_store: Event store for resumability support. If provided,
resumability will be enabled, allowing clients to
reconnect and replay missed messages.
Raises:
ValueError: If the session ID contains invalid characters.
@@ -98,8 +160,9 @@ class StreamableHTTPServerTransport:
self.mcp_session_id = mcp_session_id
self.is_json_response_enabled = is_json_response_enabled
self._event_store = event_store
self._request_streams: dict[
RequestId, MemoryObjectSendStream[JSONRPCMessage]
RequestId, MemoryObjectSendStream[EventMessage]
] = {}
self._terminated = False
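Enabling resumability is then a single constructor argument. A hypothetical wiring example, assuming an InMemoryEventStore along the lines of the sketch above:

transport = StreamableHTTPServerTransport(
    mcp_session_id="26bb1ba3a9c640a7",  # any visible-ASCII session id
    is_json_response_enabled=False,      # keep SSE streaming responses
    event_store=InMemoryEventStore(),    # presence of a store enables resumability
)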
@@ -160,6 +223,21 @@ class StreamableHTTPServerTransport:
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
"""Create event data dictionary from an EventMessage."""
event_data = {
"event": "message",
"data": event_message.message.model_dump_json(
by_alias=True, exclude_none=True
),
}
# If an event ID was provided, include it
if event_message.event_id:
event_data["id"] = event_message.event_id
return event_data
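On the wire, each dictionary produced here becomes one SSE event, with the stored event ID carried in the SSE id field so a client can echo it back on reconnect. An illustrative rendering (the payload is made up for the example):

id: 26bb1ba3a9c640a7
event: message
data: {"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info","data":"hello"}}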
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests"""
request = Request(scope, receive)
@@ -308,7 +386,7 @@ class StreamableHTTPServerTransport:
request_id = str(message.root.id)
# Create promise stream for getting response
request_stream_writer, request_stream_reader = (
anyio.create_memory_object_stream[JSONRPCMessage](0)
anyio.create_memory_object_stream[EventMessage](0)
)
# Register this stream for the request ID
@@ -323,16 +401,18 @@ class StreamableHTTPServerTransport:
response_message = None
# Use similar approach to SSE writer for consistency
async for received_message in request_stream_reader:
async for event_message in request_stream_reader:
# If it's a response, this is what we're waiting for
if isinstance(
received_message.root, JSONRPCResponse | JSONRPCError
event_message.message.root, JSONRPCResponse | JSONRPCError
):
response_message = received_message
response_message = event_message.message
break
# For notifications and requests, keep waiting
else:
logger.debug(f"received: {received_message.root.method}")
logger.debug(
f"received: {event_message.message.root.method}"
)
# At this point we should have a response
if response_message:
@@ -366,7 +446,7 @@ class StreamableHTTPServerTransport:
else:
# Create SSE stream
sse_stream_writer, sse_stream_reader = (
anyio.create_memory_object_stream[dict[str, Any]](0)
anyio.create_memory_object_stream[dict[str, str]](0)
)
async def sse_writer():
@@ -374,20 +454,14 @@ class StreamableHTTPServerTransport:
try:
async with sse_stream_writer, request_stream_reader:
# Process messages from the request-specific stream
async for received_message in request_stream_reader:
async for event_message in request_stream_reader:
# Build the event data
event_data = {
"event": "message",
"data": received_message.model_dump_json(
by_alias=True, exclude_none=True
),
}
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
# If response, remove from pending streams and close
if isinstance(
received_message.root,
event_message.message.root,
JSONRPCResponse | JSONRPCError,
):
if request_id:
@@ -472,6 +546,10 @@ class StreamableHTTPServerTransport:
if not await self._validate_session(request, send):
return
# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
await self._replay_events(last_event_id, request, send)
return
headers = {
"Cache-Control": "no-cache, no-transform",
@@ -493,14 +571,14 @@ class StreamableHTTPServerTransport:
# Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
dict[str, Any]
dict[str, str]
](0)
async def standalone_sse_writer():
try:
# Create a standalone message stream for server-initiated messages
standalone_stream_writer, standalone_stream_reader = (
anyio.create_memory_object_stream[JSONRPCMessage](0)
anyio.create_memory_object_stream[EventMessage](0)
)
# Register this stream using the special key
@@ -508,20 +586,14 @@ class StreamableHTTPServerTransport:
async with sse_stream_writer, standalone_stream_reader:
# Process messages from the standalone stream
async for received_message in standalone_stream_reader:
async for event_message in standalone_stream_reader:
# For the standalone stream, we handle:
# - JSONRPCNotification (server sends notifications to client)
# - JSONRPCRequest (server sends requests to client)
# We should NOT receive JSONRPCResponse
# Send the message via SSE
event_data = {
"event": "message",
"data": received_message.model_dump_json(
by_alias=True, exclude_none=True
),
}
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception as e:
logger.exception(f"Error in standalone SSE writer: {e}")
@@ -639,6 +711,82 @@ class StreamableHTTPServerTransport:
return True
async def _replay_events(
self, last_event_id: str, request: Request, send: Send
) -> None:
"""
Replays events that would have been sent after the specified event ID.
Only used when resumability is enabled.
"""
event_store = self._event_store
if not event_store:
return
try:
headers = {
"Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive",
"Content-Type": CONTENT_TYPE_SSE,
}
if self.mcp_session_id:
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
# Create SSE stream for replay
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
dict[str, str]
](0)
async def replay_sender():
try:
async with sse_stream_writer:
# Define an async callback for sending events
async def send_event(event_message: EventMessage) -> None:
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
# Replay past events and get the stream ID
stream_id = await event_store.replay_events_after(
last_event_id, send_event
)
# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams:
msg_writer, msg_reader = anyio.create_memory_object_stream[
EventMessage
](0)
self._request_streams[stream_id] = msg_writer
# Forward messages to SSE
async with msg_reader:
async for event_message in msg_reader:
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception as e:
logger.exception(f"Error in replay sender: {e}")
# Create and start EventSourceResponse
response = EventSourceResponse(
content=sse_stream_reader,
data_sender_callable=replay_sender,
headers=headers,
)
try:
await response(request.scope, request.receive, send)
except Exception as e:
logger.exception(f"Error in replay response: {e}")
except Exception as e:
logger.exception(f"Error replaying events: {e}")
response = self._create_error_response(
f"Error replaying events: {str(e)}",
HTTPStatus.INTERNAL_SERVER_ERROR,
INTERNAL_ERROR,
)
await response(request.scope, request.receive, send)
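On the client side, resuming follows the standard SSE reconnect pattern: re-open the GET stream and send the ID of the last event processed in the Last-Event-ID header. A hedged sketch using httpx, where the URL and the literal header names are assumptions for illustration:

import httpx

async def resume_stream(url: str, session_id: str, last_event_id: str) -> None:
    headers = {
        "Accept": "text/event-stream",
        "mcp-session-id": session_id,    # assumed value of MCP_SESSION_ID_HEADER
        "last-event-id": last_event_id,  # assumed value of LAST_EVENT_ID_HEADER
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, headers=headers) as response:
            async for line in response.aiter_lines():
                # Replayed events arrive as ordinary SSE lines:
                # "id: <event id>", "event: message", "data: <json-rpc>"
                print(line)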
@asynccontextmanager
async def connect(
self,
@@ -691,10 +839,22 @@ class StreamableHTTPServerTransport:
target_request_id = str(message.root.id)
request_stream_id = target_request_id or GET_STREAM_KEY
# Store the event if we have an event store,
# regardless of whether a client is connected;
# missed messages will be replayed when the client reconnects
event_id = None
if self._event_store:
event_id = await self._event_store.store_event(
request_stream_id, message
)
logger.debug(f"Stored {event_id} from {request_stream_id}")
if request_stream_id in self._request_streams:
try:
# Send both the message and the event ID
await self._request_streams[request_stream_id].send(
message
EventMessage(message, event_id)
)
except (
anyio.BrokenResourceError,