StreamableHttp -- resumability support for servers (#587)

Author: ihrpr
Date: 2025-05-02 14:10:40 +01:00
Committed by: GitHub
Parent: 9dfc925090
Commit: 3978c6e1b9
5 changed files with 340 additions and 55 deletions

View File

@@ -9,6 +9,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en
- Task management with anyio task groups
- Ability to send multiple notifications over time to the client
- Proper resource cleanup and lifespan management
- Resumability support via InMemoryEventStore
## Usage
@@ -32,6 +33,23 @@ The server exposes a tool named "start-notification-stream" that accepts three a
- `count`: Number of notifications to send (e.g., 5)
- `caller`: Identifier string for the caller
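For reference, the tool call is an ordinary JSON-RPC `tools/call` request. A hypothetical payload (the argument values are illustrative, and the initialize handshake and session headers that precede it are omitted):

```python
# Hypothetical tools/call payload; values are illustrative.
payload = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",
    "params": {
        "name": "start-notification-stream",
        "arguments": {"interval": 1.0, "count": 5, "caller": "demo-client"},
    },
}
```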
## Resumability Support
This server includes resumability support through the InMemoryEventStore. This enables clients to:
- Reconnect to the server after a disconnection
- Resume event streaming from where they left off using the Last-Event-ID header
The server will:
- Generate unique event IDs for each SSE message
- Store events in memory for later replay
- Replay missed events when a client reconnects with a Last-Event-ID header
Note: The InMemoryEventStore is designed for demonstration purposes only. For production use, consider implementing a persistent storage solution.
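A rough sketch of what resumption looks like from the client side (assuming `httpx`; the URL, port, session ID, and event ID are placeholders, and the header names mirror the session-ID and Last-Event-ID headers used by the transport):

```python
# Hypothetical client-side resume sketch; endpoint and IDs are placeholders.
import httpx

async def resume(session_id: str, last_event_id: str) -> None:
    headers = {
        "Accept": "text/event-stream",
        "mcp-session-id": session_id,
        "Last-Event-ID": last_event_id,  # where to resume the stream from
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET", "http://localhost:3000/mcp", headers=headers
        ) as response:
            async for line in response.aiter_lines():
                print(line)  # missed events are replayed first, then live ones
```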
## Client
You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector]
You can connect to this server using an HTTP client. For now, only the TypeScript SDK has Streamable HTTP client examples; alternatively, you can use [Inspector](https://github.com/modelcontextprotocol/inspector)

View File

@@ -0,0 +1,105 @@
"""
In-memory event store for demonstrating resumability functionality.
This is a simple implementation intended for examples and testing,
not for production use where a persistent storage solution would be more appropriate.
"""
import logging
from collections import deque
from dataclasses import dataclass
from uuid import uuid4
from mcp.server.streamable_http import (
EventCallback,
EventId,
EventMessage,
EventStore,
StreamId,
)
from mcp.types import JSONRPCMessage
logger = logging.getLogger(__name__)
@dataclass
class EventEntry:
"""
Represents an event entry in the event store.
"""
event_id: EventId
stream_id: StreamId
message: JSONRPCMessage
class InMemoryEventStore(EventStore):
"""
Simple in-memory implementation of the EventStore interface for resumability.
This is primarily intended for examples and testing, not for production use
where a persistent storage solution would be more appropriate.
This implementation keeps only the last N events per stream for memory efficiency.
"""
def __init__(self, max_events_per_stream: int = 100):
"""Initialize the event store.
Args:
max_events_per_stream: Maximum number of events to keep per stream
"""
self.max_events_per_stream = max_events_per_stream
# for maintaining last N events per stream
self.streams: dict[StreamId, deque[EventEntry]] = {}
# event_id -> EventEntry for quick lookup
self.event_index: dict[EventId, EventEntry] = {}
async def store_event(
self, stream_id: StreamId, message: JSONRPCMessage
) -> EventId:
"""Stores an event with a generated event ID."""
event_id = str(uuid4())
event_entry = EventEntry(
event_id=event_id, stream_id=stream_id, message=message
)
# Get or create deque for this stream
if stream_id not in self.streams:
self.streams[stream_id] = deque(maxlen=self.max_events_per_stream)
# If deque is full, the oldest event will be automatically removed
# We need to remove it from the event_index as well
if len(self.streams[stream_id]) == self.max_events_per_stream:
oldest_event = self.streams[stream_id][0]
self.event_index.pop(oldest_event.event_id, None)
# Add new event
self.streams[stream_id].append(event_entry)
self.event_index[event_id] = event_entry
return event_id
async def replay_events_after(
self,
last_event_id: EventId,
send_callback: EventCallback,
) -> StreamId | None:
"""Replays events that occurred after the specified event ID."""
if last_event_id not in self.event_index:
logger.warning(f"Event ID {last_event_id} not found in store")
return None
# Get the stream and find events after the last one
last_event = self.event_index[last_event_id]
stream_id = last_event.stream_id
stream_events = self.streams.get(last_event.stream_id, deque())
# Events in deque are already in chronological order
found_last = False
for event in stream_events:
if found_last:
await send_callback(EventMessage(event.message, event.event_id))
elif event.event_id == last_event_id:
found_last = True
return stream_id
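A quick, hypothetical exercise of the contract above (not part of the commit; it assumes this module is importable and that `JSONRPCMessage` validates a plain notification dict):

```python
# Hypothetical demo of store_event/replay_events_after.
import anyio

from mcp.types import JSONRPCMessage

async def demo() -> None:
    store = InMemoryEventStore(max_events_per_stream=2)
    ids: list[str] = []
    for n in range(3):
        msg = JSONRPCMessage.model_validate(
            {"jsonrpc": "2.0", "method": "notifications/message", "params": {"n": n}}
        )
        ids.append(await store.store_event("stream-1", msg))

    async def send(event_message) -> None:  # matches EventCallback
        print(event_message.event_id, event_message.message)

    # With maxlen=2 the first event was evicted; replaying after ids[1]
    # sends only the third event and returns the stream ID.
    assert await store.replay_events_after(ids[1], send) == "stream-1"

anyio.run(demo)
```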

View File

@@ -17,12 +17,24 @@ from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
from .event_store import InMemoryEventStore
# Configure logging
logger = logging.getLogger(__name__)
# Global task group that will be initialized in the lifespan
task_group = None
# Event store for resumability
# The InMemoryEventStore enables resumability support for StreamableHTTP transport.
# It stores SSE events with unique IDs, allowing clients to:
# 1. Receive event IDs for each SSE message
# 2. Resume streams by sending Last-Event-ID in GET requests
# 3. Replay missed events after reconnection
# Note: This in-memory implementation is for demonstration ONLY.
# For production, use a persistent storage solution.
event_store = InMemoryEventStore()
@contextlib.asynccontextmanager
async def lifespan(app):
@@ -79,9 +91,14 @@ def main(
# Send the specified number of notifications with the given interval
for i in range(count):
# Include more detailed message for resumability demonstration
notification_msg = (
f"[{i+1}/{count}] Event from '{caller}' - "
f"Use Last-Event-ID to resume if disconnected"
)
await ctx.session.send_log_message(
level="info",
data=f"Notification {i+1}/{count} from caller: {caller}",
data=notification_msg,
logger="notification_stream",
# Associates this notification with the original request
# Ensures notifications are sent to the correct response stream
@@ -90,6 +107,7 @@ def main(
# - nowhere (if GET request isn't supported)
related_request_id=ctx.request_id,
)
logger.debug(f"Sent notification {i+1}/{count} for caller: {caller}")
if i < count - 1: # Don't wait after the last notification
await anyio.sleep(interval)
@@ -163,8 +181,10 @@ def main(
http_transport = StreamableHTTPServerTransport(
mcp_session_id=new_session_id,
is_json_response_enabled=json_response,
event_store=event_store, # Enable resumability
)
server_instances[http_transport.mcp_session_id] = http_transport
logger.info(f"Created new transport with session ID: {new_session_id}")
async with http_transport.connect() as streams:
read_stream, write_stream = streams
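The "for production, use a persistent storage solution" note above marks the seam: anything implementing the `EventStore` interface can be handed to the transport in place of `InMemoryEventStore`. A hypothetical skeleton (the `DatabaseEventStore` name and its backend are assumptions, not part of this commit):

```python
# Hypothetical persistent EventStore skeleton; only the two abstract
# methods of the interface need to be implemented.
from mcp.server.streamable_http import (
    EventCallback,
    EventId,
    EventStore,
    StreamId,
)
from mcp.types import JSONRPCMessage

class DatabaseEventStore(EventStore):
    async def store_event(
        self, stream_id: StreamId, message: JSONRPCMessage
    ) -> EventId:
        # Persist the serialized message under a fresh event ID and return it.
        raise NotImplementedError

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> StreamId | None:
        # Load events recorded after last_event_id on the same stream and
        # await send_callback(EventMessage(message, event_id)) for each.
        raise NotImplementedError

# event_store = DatabaseEventStore()  # drop-in for InMemoryEventStore()
```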

View File

@@ -10,10 +10,11 @@ responses, with streaming support for long-running operations.
import json
import logging
import re
from collections.abc import AsyncGenerator
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -57,6 +58,63 @@ GET_STREAM_KEY = "_GET_stream"
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
# Type aliases
StreamId = str
EventId = str
@dataclass
class EventMessage:
"""
A JSONRPCMessage with an optional event ID for stream resumability.
"""
message: JSONRPCMessage
event_id: str | None = None
EventCallback = Callable[[EventMessage], Awaitable[None]]
class EventStore(ABC):
"""
Interface for resumability support via event storage.
"""
@abstractmethod
async def store_event(
self, stream_id: StreamId, message: JSONRPCMessage
) -> EventId:
"""
Stores an event for later retrieval.
Args:
stream_id: ID of the stream the event belongs to
message: The JSON-RPC message to store
Returns:
The generated event ID for the stored event
"""
pass
@abstractmethod
async def replay_events_after(
self,
last_event_id: EventId,
send_callback: EventCallback,
) -> StreamId | None:
"""
Replays events that occurred after the specified event ID.
Args:
last_event_id: The ID of the last event the client received
send_callback: A callback function to send events to the client
Returns:
The stream ID of the replayed events, or None if the event ID is not found
"""
pass
class StreamableHTTPServerTransport:
"""
@@ -76,6 +134,7 @@ class StreamableHTTPServerTransport:
self,
mcp_session_id: str | None,
is_json_response_enabled: bool = False,
event_store: EventStore | None = None,
) -> None:
"""
Initialize a new StreamableHTTP server transport.
@@ -85,6 +144,9 @@ class StreamableHTTPServerTransport:
Must contain only visible ASCII characters (0x21-0x7E).
is_json_response_enabled: If True, return JSON responses for requests
instead of SSE streams. Default is False.
event_store: Event store for resumability support. If provided,
resumability will be enabled, allowing clients to
reconnect and replay missed messages.
Raises:
ValueError: If the session ID contains invalid characters.
@@ -98,8 +160,9 @@ class StreamableHTTPServerTransport:
self.mcp_session_id = mcp_session_id
self.is_json_response_enabled = is_json_response_enabled
self._event_store = event_store
self._request_streams: dict[
RequestId, MemoryObjectSendStream[JSONRPCMessage]
RequestId, MemoryObjectSendStream[EventMessage]
] = {}
self._terminated = False
@@ -160,6 +223,21 @@ class StreamableHTTPServerTransport:
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
"""Create event data dictionary from an EventMessage."""
event_data = {
"event": "message",
"data": event_message.message.model_dump_json(
by_alias=True, exclude_none=True
),
}
# If an event ID was provided, include it
if event_message.event_id:
event_data["id"] = event_message.event_id
return event_data
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests"""
request = Request(scope, receive)
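For orientation, the dict returned by `_create_event_data` above is what the SSE response machinery serializes onto the wire; with an event ID attached, a frame looks roughly like this (values are illustrative):

```python
# Illustrative only; the UUID is a placeholder generated by the event store.
event_data = {
    "event": "message",
    "data": '{"jsonrpc": "2.0", "method": "notifications/message", "params": {}}',
    "id": "c1a2b3d4-5e6f-7a8b-9c0d-1e2f3a4b5c6d",
}
# On the wire:
#   id: c1a2b3d4-5e6f-7a8b-9c0d-1e2f3a4b5c6d
#   event: message
#   data: {"jsonrpc": "2.0", "method": "notifications/message", "params": {}}
```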
@@ -308,7 +386,7 @@ class StreamableHTTPServerTransport:
request_id = str(message.root.id)
# Create promise stream for getting response
request_stream_writer, request_stream_reader = (
anyio.create_memory_object_stream[JSONRPCMessage](0)
anyio.create_memory_object_stream[EventMessage](0)
)
# Register this stream for the request ID
@@ -323,16 +401,18 @@ class StreamableHTTPServerTransport:
response_message = None
# Use a similar approach to the SSE writer for consistency
async for received_message in request_stream_reader:
async for event_message in request_stream_reader:
# If it's a response, this is what we're waiting for
if isinstance(
received_message.root, JSONRPCResponse | JSONRPCError
event_message.message.root, JSONRPCResponse | JSONRPCError
):
response_message = received_message
response_message = event_message.message
break
# For notifications and requests, keep waiting
else:
logger.debug(f"received: {received_message.root.method}")
logger.debug(
f"received: {event_message.message.root.method}"
)
# At this point we should have a response
if response_message:
@@ -366,7 +446,7 @@ class StreamableHTTPServerTransport:
else:
# Create SSE stream
sse_stream_writer, sse_stream_reader = (
anyio.create_memory_object_stream[dict[str, Any]](0)
anyio.create_memory_object_stream[dict[str, str]](0)
)
async def sse_writer():
@@ -374,20 +454,14 @@ class StreamableHTTPServerTransport:
try:
async with sse_stream_writer, request_stream_reader:
# Process messages from the request-specific stream
async for received_message in request_stream_reader:
async for event_message in request_stream_reader:
# Build the event data
event_data = {
"event": "message",
"data": received_message.model_dump_json(
by_alias=True, exclude_none=True
),
}
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
# If response, remove from pending streams and close
if isinstance(
received_message.root,
event_message.message.root,
JSONRPCResponse | JSONRPCError,
):
if request_id:
@@ -472,6 +546,10 @@ class StreamableHTTPServerTransport:
if not await self._validate_session(request, send):
return
# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
await self._replay_events(last_event_id, request, send)
return
headers = {
"Cache-Control": "no-cache, no-transform",
@@ -493,14 +571,14 @@ class StreamableHTTPServerTransport:
# Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
dict[str, Any]
dict[str, str]
](0)
async def standalone_sse_writer():
try:
# Create a standalone message stream for server-initiated messages
standalone_stream_writer, standalone_stream_reader = (
anyio.create_memory_object_stream[JSONRPCMessage](0)
anyio.create_memory_object_stream[EventMessage](0)
)
# Register this stream using the special key
@@ -508,20 +586,14 @@ class StreamableHTTPServerTransport:
async with sse_stream_writer, standalone_stream_reader:
# Process messages from the standalone stream
async for received_message in standalone_stream_reader:
async for event_message in standalone_stream_reader:
# For the standalone stream, we handle:
# - JSONRPCNotification (server sends notifications to client)
# - JSONRPCRequest (server sends requests to client)
# We should NOT receive JSONRPCResponse
# Send the message via SSE
event_data = {
"event": "message",
"data": received_message.model_dump_json(
by_alias=True, exclude_none=True
),
}
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception as e:
logger.exception(f"Error in standalone SSE writer: {e}")
@@ -639,6 +711,82 @@ class StreamableHTTPServerTransport:
return True
async def _replay_events(
self, last_event_id: str, request: Request, send: Send
) -> None:
"""
Replays events that would have been sent after the specified event ID.
Only used when resumability is enabled.
"""
event_store = self._event_store
if not event_store:
return
try:
headers = {
"Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive",
"Content-Type": CONTENT_TYPE_SSE,
}
if self.mcp_session_id:
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
# Create SSE stream for replay
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
dict[str, str]
](0)
async def replay_sender():
try:
async with sse_stream_writer:
# Define an async callback for sending events
async def send_event(event_message: EventMessage) -> None:
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
# Replay past events and get the stream ID
stream_id = await event_store.replay_events_after(
last_event_id, send_event
)
# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams:
msg_writer, msg_reader = anyio.create_memory_object_stream[
EventMessage
](0)
self._request_streams[stream_id] = msg_writer
# Forward messages to SSE
async with msg_reader:
async for event_message in msg_reader:
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception as e:
logger.exception(f"Error in replay sender: {e}")
# Create and start EventSourceResponse
response = EventSourceResponse(
content=sse_stream_reader,
data_sender_callable=replay_sender,
headers=headers,
)
try:
await response(request.scope, request.receive, send)
except Exception as e:
logger.exception(f"Error in replay response: {e}")
except Exception as e:
logger.exception(f"Error replaying events: {e}")
response = self._create_error_response(
f"Error replaying events: {str(e)}",
HTTPStatus.INTERNAL_SERVER_ERROR,
INTERNAL_ERROR,
)
await response(request.scope, request.receive, send)
@asynccontextmanager
async def connect(
self,
@@ -691,10 +839,22 @@ class StreamableHTTPServerTransport:
target_request_id = str(message.root.id)
request_stream_id = target_request_id or GET_STREAM_KEY
# Store the event if we have an event store,
# regardless of whether a client is connected;
# messages will be replayed on reconnect.
event_id = None
if self._event_store:
event_id = await self._event_store.store_event(
request_stream_id, message
)
logger.debug(f"Stored {event_id} from {request_stream_id}")
if request_stream_id in self._request_streams:
try:
# Send both the message and the event ID
await self._request_streams[request_stream_id].send(
message
EventMessage(message, event_id)
)
except (
anyio.BrokenResourceError,

View File

@@ -113,15 +113,12 @@ def create_app(is_json_response_enabled=False) -> Starlette:
async with anyio.create_task_group() as tg:
task_group = tg
print("Application started, task group initialized!")
try:
yield
finally:
print("Application shutting down, cleaning up resources...")
if task_group:
tg.cancel_scope.cancel()
task_group = None
print("Resources cleaned up successfully.")
async def handle_streamable_http(scope, receive, send):
request = Request(scope, receive)
@@ -148,14 +145,11 @@ def create_app(is_json_response_enabled=False) -> Starlette:
read_stream, write_stream = streams
async def run_server():
try:
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
)
except Exception as e:
print(f"Server exception: {e}")
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
)
if task_group is None:
response = Response(
@@ -196,10 +190,6 @@ def run_server(port: int, is_json_response_enabled=False) -> None:
port: Port to listen on.
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
"""
print(
f"Starting test server on port {port} with "
f"json_enabled={is_json_response_enabled}"
)
app = create_app(is_json_response_enabled)
# Configure server
@@ -218,16 +208,12 @@ def run_server(port: int, is_json_response_enabled=False) -> None:
# This is important to catch exceptions and prevent test hangs
try:
print("Server starting...")
server.run()
except Exception as e:
print(f"ERROR: Server failed to run: {e}")
except Exception:
import traceback
traceback.print_exc()
print("Server shutdown")
# Test fixtures - using same approach as SSE tests
@pytest.fixture
@@ -273,8 +259,6 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]:
# Clean up
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("server process failed to terminate")
@pytest.fixture
@@ -306,8 +290,6 @@ def json_response_server(json_server_port: int) -> Generator[None, None, None]:
# Clean up
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("server process failed to terminate")
@pytest.fixture