mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
StreamableHttp -- resumability support for servers (#587)
This commit is contained in:
@@ -10,10 +10,11 @@ responses, with streaming support for long-running operations.
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
@@ -57,6 +58,63 @@ GET_STREAM_KEY = "_GET_stream"
|
||||
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
|
||||
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
|
||||
|
||||
# Type aliases
|
||||
StreamId = str
|
||||
EventId = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventMessage:
|
||||
"""
|
||||
A JSONRPCMessage with an optional event ID for stream resumability.
|
||||
"""
|
||||
|
||||
message: JSONRPCMessage
|
||||
event_id: str | None = None
|
||||
|
||||
|
||||
EventCallback = Callable[[EventMessage], Awaitable[None]]
|
||||
|
||||
|
||||
class EventStore(ABC):
|
||||
"""
|
||||
Interface for resumability support via event storage.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def store_event(
|
||||
self, stream_id: StreamId, message: JSONRPCMessage
|
||||
) -> EventId:
|
||||
"""
|
||||
Stores an event for later retrieval.
|
||||
|
||||
Args:
|
||||
stream_id: ID of the stream the event belongs to
|
||||
message: The JSON-RPC message to store
|
||||
|
||||
Returns:
|
||||
The generated event ID for the stored event
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def replay_events_after(
|
||||
self,
|
||||
last_event_id: EventId,
|
||||
send_callback: EventCallback,
|
||||
) -> StreamId | None:
|
||||
"""
|
||||
Replays events that occurred after the specified event ID.
|
||||
|
||||
Args:
|
||||
last_event_id: The ID of the last event the client received
|
||||
send_callback: A callback function to send events to the client
|
||||
|
||||
Returns:
|
||||
The stream ID of the replayed events
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StreamableHTTPServerTransport:
|
||||
"""
|
||||
@@ -76,6 +134,7 @@ class StreamableHTTPServerTransport:
|
||||
self,
|
||||
mcp_session_id: str | None,
|
||||
is_json_response_enabled: bool = False,
|
||||
event_store: EventStore | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new StreamableHTTP server transport.
|
||||
@@ -85,6 +144,9 @@ class StreamableHTTPServerTransport:
|
||||
Must contain only visible ASCII characters (0x21-0x7E).
|
||||
is_json_response_enabled: If True, return JSON responses for requests
|
||||
instead of SSE streams. Default is False.
|
||||
event_store: Event store for resumability support. If provided,
|
||||
resumability will be enabled, allowing clients to
|
||||
reconnect and resume messages.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session ID contains invalid characters.
|
||||
@@ -98,8 +160,9 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
self.mcp_session_id = mcp_session_id
|
||||
self.is_json_response_enabled = is_json_response_enabled
|
||||
self._event_store = event_store
|
||||
self._request_streams: dict[
|
||||
RequestId, MemoryObjectSendStream[JSONRPCMessage]
|
||||
RequestId, MemoryObjectSendStream[EventMessage]
|
||||
] = {}
|
||||
self._terminated = False
|
||||
|
||||
@@ -160,6 +223,21 @@ class StreamableHTTPServerTransport:
|
||||
"""Extract the session ID from request headers."""
|
||||
return request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
|
||||
"""Create event data dictionary from an EventMessage."""
|
||||
event_data = {
|
||||
"event": "message",
|
||||
"data": event_message.message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
# If an event ID was provided, include it
|
||||
if event_message.event_id:
|
||||
event_data["id"] = event_message.event_id
|
||||
|
||||
return event_data
|
||||
|
||||
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Application entry point that handles all HTTP requests"""
|
||||
request = Request(scope, receive)
|
||||
@@ -308,7 +386,7 @@ class StreamableHTTPServerTransport:
|
||||
request_id = str(message.root.id)
|
||||
# Create promise stream for getting response
|
||||
request_stream_writer, request_stream_reader = (
|
||||
anyio.create_memory_object_stream[JSONRPCMessage](0)
|
||||
anyio.create_memory_object_stream[EventMessage](0)
|
||||
)
|
||||
|
||||
# Register this stream for the request ID
|
||||
@@ -323,16 +401,18 @@ class StreamableHTTPServerTransport:
|
||||
response_message = None
|
||||
|
||||
# Use similar approach to SSE writer for consistency
|
||||
async for received_message in request_stream_reader:
|
||||
async for event_message in request_stream_reader:
|
||||
# If it's a response, this is what we're waiting for
|
||||
if isinstance(
|
||||
received_message.root, JSONRPCResponse | JSONRPCError
|
||||
event_message.message.root, JSONRPCResponse | JSONRPCError
|
||||
):
|
||||
response_message = received_message
|
||||
response_message = event_message.message
|
||||
break
|
||||
# For notifications and request, keep waiting
|
||||
else:
|
||||
logger.debug(f"received: {received_message.root.method}")
|
||||
logger.debug(
|
||||
f"received: {event_message.message.root.method}"
|
||||
)
|
||||
|
||||
# At this point we should have a response
|
||||
if response_message:
|
||||
@@ -366,7 +446,7 @@ class StreamableHTTPServerTransport:
|
||||
else:
|
||||
# Create SSE stream
|
||||
sse_stream_writer, sse_stream_reader = (
|
||||
anyio.create_memory_object_stream[dict[str, Any]](0)
|
||||
anyio.create_memory_object_stream[dict[str, str]](0)
|
||||
)
|
||||
|
||||
async def sse_writer():
|
||||
@@ -374,20 +454,14 @@ class StreamableHTTPServerTransport:
|
||||
try:
|
||||
async with sse_stream_writer, request_stream_reader:
|
||||
# Process messages from the request-specific stream
|
||||
async for received_message in request_stream_reader:
|
||||
async for event_message in request_stream_reader:
|
||||
# Build the event data
|
||||
event_data = {
|
||||
"event": "message",
|
||||
"data": received_message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
event_data = self._create_event_data(event_message)
|
||||
await sse_stream_writer.send(event_data)
|
||||
|
||||
# If response, remove from pending streams and close
|
||||
if isinstance(
|
||||
received_message.root,
|
||||
event_message.message.root,
|
||||
JSONRPCResponse | JSONRPCError,
|
||||
):
|
||||
if request_id:
|
||||
@@ -472,6 +546,10 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
if not await self._validate_session(request, send):
|
||||
return
|
||||
# Handle resumability: check for Last-Event-ID header
|
||||
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
|
||||
await self._replay_events(last_event_id, request, send)
|
||||
return
|
||||
|
||||
headers = {
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
@@ -493,14 +571,14 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
# Create SSE stream
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
dict[str, Any]
|
||||
dict[str, str]
|
||||
](0)
|
||||
|
||||
async def standalone_sse_writer():
|
||||
try:
|
||||
# Create a standalone message stream for server-initiated messages
|
||||
standalone_stream_writer, standalone_stream_reader = (
|
||||
anyio.create_memory_object_stream[JSONRPCMessage](0)
|
||||
anyio.create_memory_object_stream[EventMessage](0)
|
||||
)
|
||||
|
||||
# Register this stream using the special key
|
||||
@@ -508,20 +586,14 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
async with sse_stream_writer, standalone_stream_reader:
|
||||
# Process messages from the standalone stream
|
||||
async for received_message in standalone_stream_reader:
|
||||
async for event_message in standalone_stream_reader:
|
||||
# For the standalone stream, we handle:
|
||||
# - JSONRPCNotification (server sends notifications to client)
|
||||
# - JSONRPCRequest (server sends requests to client)
|
||||
# We should NOT receive JSONRPCResponse
|
||||
|
||||
# Send the message via SSE
|
||||
event_data = {
|
||||
"event": "message",
|
||||
"data": received_message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
event_data = self._create_event_data(event_message)
|
||||
await sse_stream_writer.send(event_data)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in standalone SSE writer: {e}")
|
||||
@@ -639,6 +711,82 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
return True
|
||||
|
||||
async def _replay_events(
|
||||
self, last_event_id: str, request: Request, send: Send
|
||||
) -> None:
|
||||
"""
|
||||
Replays events that would have been sent after the specified event ID.
|
||||
Only used when resumability is enabled.
|
||||
"""
|
||||
event_store = self._event_store
|
||||
if not event_store:
|
||||
return
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": CONTENT_TYPE_SSE,
|
||||
}
|
||||
|
||||
if self.mcp_session_id:
|
||||
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
|
||||
|
||||
# Create SSE stream for replay
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
dict[str, str]
|
||||
](0)
|
||||
|
||||
async def replay_sender():
|
||||
try:
|
||||
async with sse_stream_writer:
|
||||
# Define an async callback for sending events
|
||||
async def send_event(event_message: EventMessage) -> None:
|
||||
event_data = self._create_event_data(event_message)
|
||||
await sse_stream_writer.send(event_data)
|
||||
|
||||
# Replay past events and get the stream ID
|
||||
stream_id = await event_store.replay_events_after(
|
||||
last_event_id, send_event
|
||||
)
|
||||
|
||||
# If stream ID not in mapping, create it
|
||||
if stream_id and stream_id not in self._request_streams:
|
||||
msg_writer, msg_reader = anyio.create_memory_object_stream[
|
||||
EventMessage
|
||||
](0)
|
||||
self._request_streams[stream_id] = msg_writer
|
||||
|
||||
# Forward messages to SSE
|
||||
async with msg_reader:
|
||||
async for event_message in msg_reader:
|
||||
event_data = self._create_event_data(event_message)
|
||||
|
||||
await sse_stream_writer.send(event_data)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in replay sender: {e}")
|
||||
|
||||
# Create and start EventSourceResponse
|
||||
response = EventSourceResponse(
|
||||
content=sse_stream_reader,
|
||||
data_sender_callable=replay_sender,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
try:
|
||||
await response(request.scope, request.receive, send)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in replay response: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error replaying events: {e}")
|
||||
response = self._create_error_response(
|
||||
f"Error replaying events: {str(e)}",
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
INTERNAL_ERROR,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(
|
||||
self,
|
||||
@@ -691,10 +839,22 @@ class StreamableHTTPServerTransport:
|
||||
target_request_id = str(message.root.id)
|
||||
|
||||
request_stream_id = target_request_id or GET_STREAM_KEY
|
||||
|
||||
# Store the event if we have an event store,
|
||||
# regardless of whether a client is connected
|
||||
# messages will be replayed on the re-connect
|
||||
event_id = None
|
||||
if self._event_store:
|
||||
event_id = await self._event_store.store_event(
|
||||
request_stream_id, message
|
||||
)
|
||||
logger.debug(f"Stored {event_id} from {request_stream_id}")
|
||||
|
||||
if request_stream_id in self._request_streams:
|
||||
try:
|
||||
# Send both the message and the event ID
|
||||
await self._request_streams[request_stream_id].send(
|
||||
message
|
||||
EventMessage(message, event_id)
|
||||
)
|
||||
except (
|
||||
anyio.BrokenResourceError,
|
||||
|
||||
Reference in New Issue
Block a user