Mirror of https://github.com/aljazceru/mcp-python-sdk.git, synced 2025-12-19 06:54:18 +01:00
StreamableHttp -- resumability support for servers (#587)
@@ -9,6 +9,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en
 - Task management with anyio task groups
 - Ability to send multiple notifications over time to the client
 - Proper resource cleanup and lifespan management
+- Resumability support via InMemoryEventStore

 ## Usage

@@ -32,6 +33,23 @@ The server exposes a tool named "start-notification-stream" that accepts three a
 - `count`: Number of notifications to send (e.g., 5)
 - `caller`: Identifier string for the caller

+## Resumability Support
+
+This server includes resumability support through the InMemoryEventStore. This enables clients to:
+
+- Reconnect to the server after a disconnection
+- Resume event streaming from where they left off using the Last-Event-ID header
+
+The server will:
+
+- Generate unique event IDs for each SSE message
+- Store events in memory for later replay
+- Replay missed events when a client reconnects with a Last-Event-ID header
+
+Note: The InMemoryEventStore is designed for demonstration purposes only. For production use, consider implementing a persistent storage solution.
+
 ## Client

-You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector]
+You can connect to this server using an HTTP client. For now, only the TypeScript SDK has streamable HTTP client examples, or you can use [Inspector](https://github.com/modelcontextprotocol/inspector).
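As a hands-on illustration of the resumability flow described above, here is a minimal sketch of reconnecting to a stream by hand. It is not part of this commit: `httpx` as the client library, the `http://localhost:3000/mcp` endpoint, and the placeholder session/event IDs are all assumptions made for the example.

```python
"""Sketch: manually resuming an SSE stream after a disconnect."""
import anyio
import httpx

SESSION_ID = "..."     # hypothetical: value of the mcp-session-id response header
LAST_EVENT_ID = "..."  # hypothetical: id: field of the last SSE event received


async def resume_stream() -> None:
    headers = {
        "Accept": "text/event-stream",
        "mcp-session-id": SESSION_ID,
        # Ask the server to replay everything stored after this event
        "Last-Event-ID": LAST_EVENT_ID,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET", "http://localhost:3000/mcp", headers=headers
        ) as resp:
            async for line in resp.aiter_lines():
                print(line)  # raw SSE lines: id:, event:, data:


anyio.run(resume_stream)
```

The server first replays any stored events after `LAST_EVENT_ID`, then keeps the stream open for new server-initiated messages. The replay is backed by the commit's new event store, added as a new file in the `@@ -0,0 +1,105 @@` hunk below.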
@@ -0,0 +1,105 @@
+"""
+In-memory event store for demonstrating resumability functionality.
+
+This is a simple implementation intended for examples and testing,
+not for production use where a persistent storage solution would be more appropriate.
+"""
+
+import logging
+from collections import deque
+from dataclasses import dataclass
+from uuid import uuid4
+
+from mcp.server.streamable_http import (
+    EventCallback,
+    EventId,
+    EventMessage,
+    EventStore,
+    StreamId,
+)
+from mcp.types import JSONRPCMessage
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EventEntry:
+    """
+    Represents an event entry in the event store.
+    """
+
+    event_id: EventId
+    stream_id: StreamId
+    message: JSONRPCMessage
+
+
+class InMemoryEventStore(EventStore):
+    """
+    Simple in-memory implementation of the EventStore interface for resumability.
+    This is primarily intended for examples and testing, not for production use
+    where a persistent storage solution would be more appropriate.
+
+    This implementation keeps only the last N events per stream for memory efficiency.
+    """
+
+    def __init__(self, max_events_per_stream: int = 100):
+        """Initialize the event store.
+
+        Args:
+            max_events_per_stream: Maximum number of events to keep per stream
+        """
+        self.max_events_per_stream = max_events_per_stream
+        # for maintaining last N events per stream
+        self.streams: dict[StreamId, deque[EventEntry]] = {}
+        # event_id -> EventEntry for quick lookup
+        self.event_index: dict[EventId, EventEntry] = {}
+
+    async def store_event(
+        self, stream_id: StreamId, message: JSONRPCMessage
+    ) -> EventId:
+        """Stores an event with a generated event ID."""
+        event_id = str(uuid4())
+        event_entry = EventEntry(
+            event_id=event_id, stream_id=stream_id, message=message
+        )
+
+        # Get or create deque for this stream
+        if stream_id not in self.streams:
+            self.streams[stream_id] = deque(maxlen=self.max_events_per_stream)
+
+        # If deque is full, the oldest event will be automatically removed
+        # We need to remove it from the event_index as well
+        if len(self.streams[stream_id]) == self.max_events_per_stream:
+            oldest_event = self.streams[stream_id][0]
+            self.event_index.pop(oldest_event.event_id, None)
+
+        # Add new event
+        self.streams[stream_id].append(event_entry)
+        self.event_index[event_id] = event_entry
+
+        return event_id
+
+    async def replay_events_after(
+        self,
+        last_event_id: EventId,
+        send_callback: EventCallback,
+    ) -> StreamId | None:
+        """Replays events that occurred after the specified event ID."""
+        if last_event_id not in self.event_index:
+            logger.warning(f"Event ID {last_event_id} not found in store")
+            return None
+
+        # Get the stream and find events after the last one
+        last_event = self.event_index[last_event_id]
+        stream_id = last_event.stream_id
+        stream_events = self.streams.get(last_event.stream_id, deque())
+
+        # Events in deque are already in chronological order
+        found_last = False
+        for event in stream_events:
+            if found_last:
+                await send_callback(EventMessage(event.message, event.event_id))
+            elif event.event_id == last_event_id:
+                found_last = True
+
+        return stream_id
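To make the store/replay contract above concrete, here is a small self-contained exercise of `InMemoryEventStore`. The notification payload is illustrative; `JSONRPCMessage.model_validate` is just the pydantic entry point used here to build messages.

```python
"""Sketch: exercising InMemoryEventStore outside the transport."""
import anyio

from mcp.types import JSONRPCMessage


async def main() -> None:
    store = InMemoryEventStore(max_events_per_stream=10)

    # Store three notifications on one stream (message shape is illustrative)
    ids = []
    for i in range(3):
        msg = JSONRPCMessage.model_validate(
            {"jsonrpc": "2.0", "method": "notifications/message", "params": {"n": i}}
        )
        ids.append(await store.store_event("stream-1", msg))

    async def on_replay(event) -> None:
        # event is an EventMessage carrying the original message and its ID
        print("replayed:", event.event_id)

    # The client last saw ids[1], so only the third event is replayed
    stream_id = await store.replay_events_after(ids[1], on_replay)
    assert stream_id == "stream-1"


anyio.run(main)
```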
@@ -17,12 +17,24 @@ from starlette.requests import Request
 from starlette.responses import Response
 from starlette.routing import Mount

+from .event_store import InMemoryEventStore
+
 # Configure logging
 logger = logging.getLogger(__name__)

 # Global task group that will be initialized in the lifespan
 task_group = None

+# Event store for resumability
+# The InMemoryEventStore enables resumability support for StreamableHTTP transport.
+# It stores SSE events with unique IDs, allowing clients to:
+#   1. Receive event IDs for each SSE message
+#   2. Resume streams by sending Last-Event-ID in GET requests
+#   3. Replay missed events after reconnection
+# Note: This in-memory implementation is for demonstration ONLY.
+# For production, use a persistent storage solution.
+event_store = InMemoryEventStore()
+
+
 @contextlib.asynccontextmanager
 async def lifespan(app):
@@ -79,9 +91,14 @@ def main(

         # Send the specified number of notifications with the given interval
         for i in range(count):
+            # Include more detailed message for resumability demonstration
+            notification_msg = (
+                f"[{i+1}/{count}] Event from '{caller}' - "
+                f"Use Last-Event-ID to resume if disconnected"
+            )
             await ctx.session.send_log_message(
                 level="info",
-                data=f"Notification {i+1}/{count} from caller: {caller}",
+                data=notification_msg,
                 logger="notification_stream",
                 # Associates this notification with the original request
                 # Ensures notifications are sent to the correct response stream
@@ -90,6 +107,7 @@ def main(
                 # - nowhere (if GET request isn't supported)
                 related_request_id=ctx.request_id,
             )
+            logger.debug(f"Sent notification {i+1}/{count} for caller: {caller}")
             if i < count - 1:  # Don't wait after the last notification
                 await anyio.sleep(interval)
@@ -163,8 +181,10 @@ def main(
         http_transport = StreamableHTTPServerTransport(
             mcp_session_id=new_session_id,
             is_json_response_enabled=json_response,
+            event_store=event_store,  # Enable resumability
         )
         server_instances[http_transport.mcp_session_id] = http_transport
+        logger.info(f"Created new transport with session ID: {new_session_id}")
         async with http_transport.connect() as streams:
             read_stream, write_stream = streams

@@ -10,10 +10,11 @@ responses, with streaming support for long-running operations.
 import json
 import logging
 import re
-from collections.abc import AsyncGenerator
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator, Awaitable, Callable
 from contextlib import asynccontextmanager
+from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Any

 import anyio
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -57,6 +58,63 @@ GET_STREAM_KEY = "_GET_stream"
 # Pattern ensures entire string contains only valid characters by using ^ and $ anchors
 SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")

+# Type aliases
+StreamId = str
+EventId = str
+
+
+@dataclass
+class EventMessage:
+    """
+    A JSONRPCMessage with an optional event ID for stream resumability.
+    """
+
+    message: JSONRPCMessage
+    event_id: str | None = None
+
+
+EventCallback = Callable[[EventMessage], Awaitable[None]]
+
+
+class EventStore(ABC):
+    """
+    Interface for resumability support via event storage.
+    """
+
+    @abstractmethod
+    async def store_event(
+        self, stream_id: StreamId, message: JSONRPCMessage
+    ) -> EventId:
+        """
+        Stores an event for later retrieval.
+
+        Args:
+            stream_id: ID of the stream the event belongs to
+            message: The JSON-RPC message to store
+
+        Returns:
+            The generated event ID for the stored event
+        """
+        pass
+
+    @abstractmethod
+    async def replay_events_after(
+        self,
+        last_event_id: EventId,
+        send_callback: EventCallback,
+    ) -> StreamId | None:
+        """
+        Replays events that occurred after the specified event ID.
+
+        Args:
+            last_event_id: The ID of the last event the client received
+            send_callback: A callback function to send events to the client
+
+        Returns:
+            The stream ID of the replayed events
+        """
+        pass
+
+
 class StreamableHTTPServerTransport:
     """
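The interface docstrings above, like the example README, stress that in-memory storage is for demonstrations only. As a hedged sketch of what a persistent implementation could look like, here is an `EventStore` backed by the standard library's `sqlite3`. None of this is in the commit; `sqlite3` is synchronous, so a production version would push the queries onto a worker thread (for example via `anyio.to_thread.run_sync`) and add a retention policy.

```python
"""Sketch only: a persistent EventStore backed by sqlite3."""
import sqlite3
from uuid import uuid4

from mcp.server.streamable_http import (
    EventCallback,
    EventId,
    EventMessage,
    EventStore,
    StreamId,
)
from mcp.types import JSONRPCMessage


class SqliteEventStore(EventStore):
    def __init__(self, path: str = "events.db") -> None:
        self._db = sqlite3.connect(path)
        # A monotonically increasing seq column provides replay ordering
        self._db.execute(
            "CREATE TABLE IF NOT EXISTS events ("
            " seq INTEGER PRIMARY KEY AUTOINCREMENT,"
            " event_id TEXT UNIQUE, stream_id TEXT, message TEXT)"
        )

    async def store_event(
        self, stream_id: StreamId, message: JSONRPCMessage
    ) -> EventId:
        event_id = str(uuid4())
        self._db.execute(
            "INSERT INTO events (event_id, stream_id, message) VALUES (?, ?, ?)",
            (
                event_id,
                stream_id,
                message.model_dump_json(by_alias=True, exclude_none=True),
            ),
        )
        self._db.commit()
        return event_id

    async def replay_events_after(
        self, last_event_id: EventId, send_callback: EventCallback
    ) -> StreamId | None:
        row = self._db.execute(
            "SELECT seq, stream_id FROM events WHERE event_id = ?", (last_event_id,)
        ).fetchone()
        if row is None:
            return None
        seq, stream_id = row
        # Replay every later event stored on the same stream, in order
        for event_id, message_json in self._db.execute(
            "SELECT event_id, message FROM events"
            " WHERE stream_id = ? AND seq > ? ORDER BY seq",
            (stream_id, seq),
        ):
            await send_callback(
                EventMessage(JSONRPCMessage.model_validate_json(message_json), event_id)
            )
        return stream_id
```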
@@ -76,6 +134,7 @@ class StreamableHTTPServerTransport:
         self,
         mcp_session_id: str | None,
         is_json_response_enabled: bool = False,
+        event_store: EventStore | None = None,
     ) -> None:
         """
         Initialize a new StreamableHTTP server transport.
@@ -85,6 +144,9 @@ class StreamableHTTPServerTransport:
                             Must contain only visible ASCII characters (0x21-0x7E).
             is_json_response_enabled: If True, return JSON responses for requests
                                       instead of SSE streams. Default is False.
+            event_store: Event store for resumability support. If provided,
+                         resumability will be enabled, allowing clients to
+                         reconnect and resume messages.

         Raises:
             ValueError: If the session ID contains invalid characters.
@@ -98,8 +160,9 @@ class StreamableHTTPServerTransport:

         self.mcp_session_id = mcp_session_id
         self.is_json_response_enabled = is_json_response_enabled
+        self._event_store = event_store
         self._request_streams: dict[
-            RequestId, MemoryObjectSendStream[JSONRPCMessage]
+            RequestId, MemoryObjectSendStream[EventMessage]
         ] = {}
         self._terminated = False

@@ -160,6 +223,21 @@ class StreamableHTTPServerTransport:
         """Extract the session ID from request headers."""
         return request.headers.get(MCP_SESSION_ID_HEADER)

+    def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
+        """Create event data dictionary from an EventMessage."""
+        event_data = {
+            "event": "message",
+            "data": event_message.message.model_dump_json(
+                by_alias=True, exclude_none=True
+            ),
+        }
+
+        # If an event ID was provided, include it
+        if event_message.event_id:
+            event_data["id"] = event_message.event_id
+
+        return event_data
+
     async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
         """Application entry point that handles all HTTP requests"""
         request = Request(scope, receive)
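For orientation, the dict returned by `_create_event_data` maps directly onto an SSE frame once sse-starlette serializes it. The payload and ID below are made up:

```python
# Illustration only: return value for an EventMessage that carries an event ID
event_data = {
    "event": "message",
    "data": '{"jsonrpc": "2.0", "method": "notifications/message", "params": {}}',
    "id": "9f0c4b1e-example",  # hypothetical UUID from the event store
}
# sse-starlette's EventSourceResponse emits this as, approximately:
#
#   event: message
#   id: 9f0c4b1e-example
#   data: {"jsonrpc": "2.0", "method": "notifications/message", "params": {}}
#
# The id: line is what clients echo back in the Last-Event-ID header.
```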
@@ -308,7 +386,7 @@ class StreamableHTTPServerTransport:
                 request_id = str(message.root.id)
                 # Create promise stream for getting response
                 request_stream_writer, request_stream_reader = (
-                    anyio.create_memory_object_stream[JSONRPCMessage](0)
+                    anyio.create_memory_object_stream[EventMessage](0)
                 )

                 # Register this stream for the request ID
@@ -323,16 +401,18 @@ class StreamableHTTPServerTransport:
                 response_message = None

                 # Use similar approach to SSE writer for consistency
-                async for received_message in request_stream_reader:
+                async for event_message in request_stream_reader:
                     # If it's a response, this is what we're waiting for
                     if isinstance(
-                        received_message.root, JSONRPCResponse | JSONRPCError
+                        event_message.message.root, JSONRPCResponse | JSONRPCError
                     ):
-                        response_message = received_message
+                        response_message = event_message.message
                         break
                     # For notifications and request, keep waiting
                     else:
-                        logger.debug(f"received: {received_message.root.method}")
+                        logger.debug(
+                            f"received: {event_message.message.root.method}"
+                        )

                 # At this point we should have a response
                 if response_message:
@@ -366,7 +446,7 @@ class StreamableHTTPServerTransport:
             else:
                 # Create SSE stream
                 sse_stream_writer, sse_stream_reader = (
-                    anyio.create_memory_object_stream[dict[str, Any]](0)
+                    anyio.create_memory_object_stream[dict[str, str]](0)
                 )

                 async def sse_writer():
@@ -374,20 +454,14 @@ class StreamableHTTPServerTransport:
                     try:
                         async with sse_stream_writer, request_stream_reader:
                             # Process messages from the request-specific stream
-                            async for received_message in request_stream_reader:
+                            async for event_message in request_stream_reader:
                                 # Build the event data
-                                event_data = {
-                                    "event": "message",
-                                    "data": received_message.model_dump_json(
-                                        by_alias=True, exclude_none=True
-                                    ),
-                                }
+                                event_data = self._create_event_data(event_message)
                                 await sse_stream_writer.send(event_data)

                                 # If response, remove from pending streams and close
                                 if isinstance(
-                                    received_message.root,
+                                    event_message.message.root,
                                     JSONRPCResponse | JSONRPCError,
                                 ):
                                     if request_id:
@@ -472,6 +546,10 @@ class StreamableHTTPServerTransport:

         if not await self._validate_session(request, send):
             return
+        # Handle resumability: check for Last-Event-ID header
+        if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
+            await self._replay_events(last_event_id, request, send)
+            return

         headers = {
             "Cache-Control": "no-cache, no-transform",
@@ -493,14 +571,14 @@ class StreamableHTTPServerTransport:

         # Create SSE stream
         sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
-            dict[str, Any]
+            dict[str, str]
         ](0)

         async def standalone_sse_writer():
             try:
                 # Create a standalone message stream for server-initiated messages
                 standalone_stream_writer, standalone_stream_reader = (
-                    anyio.create_memory_object_stream[JSONRPCMessage](0)
+                    anyio.create_memory_object_stream[EventMessage](0)
                 )

                 # Register this stream using the special key
@@ -508,20 +586,14 @@ class StreamableHTTPServerTransport:

                 async with sse_stream_writer, standalone_stream_reader:
                     # Process messages from the standalone stream
-                    async for received_message in standalone_stream_reader:
+                    async for event_message in standalone_stream_reader:
                         # For the standalone stream, we handle:
                         # - JSONRPCNotification (server sends notifications to client)
                         # - JSONRPCRequest (server sends requests to client)
                         # We should NOT receive JSONRPCResponse

                         # Send the message via SSE
-                        event_data = {
-                            "event": "message",
-                            "data": received_message.model_dump_json(
-                                by_alias=True, exclude_none=True
-                            ),
-                        }
+                        event_data = self._create_event_data(event_message)
                         await sse_stream_writer.send(event_data)
             except Exception as e:
                 logger.exception(f"Error in standalone SSE writer: {e}")
@@ -639,6 +711,82 @@ class StreamableHTTPServerTransport:

         return True

+    async def _replay_events(
+        self, last_event_id: str, request: Request, send: Send
+    ) -> None:
+        """
+        Replays events that would have been sent after the specified event ID.
+        Only used when resumability is enabled.
+        """
+        event_store = self._event_store
+        if not event_store:
+            return
+
+        try:
+            headers = {
+                "Cache-Control": "no-cache, no-transform",
+                "Connection": "keep-alive",
+                "Content-Type": CONTENT_TYPE_SSE,
+            }
+
+            if self.mcp_session_id:
+                headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
+
+            # Create SSE stream for replay
+            sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
+                dict[str, str]
+            ](0)
+
+            async def replay_sender():
+                try:
+                    async with sse_stream_writer:
+                        # Define an async callback for sending events
+                        async def send_event(event_message: EventMessage) -> None:
+                            event_data = self._create_event_data(event_message)
+                            await sse_stream_writer.send(event_data)
+
+                        # Replay past events and get the stream ID
+                        stream_id = await event_store.replay_events_after(
+                            last_event_id, send_event
+                        )
+
+                        # If stream ID not in mapping, create it
+                        if stream_id and stream_id not in self._request_streams:
+                            msg_writer, msg_reader = anyio.create_memory_object_stream[
+                                EventMessage
+                            ](0)
+                            self._request_streams[stream_id] = msg_writer
+
+                            # Forward messages to SSE
+                            async with msg_reader:
+                                async for event_message in msg_reader:
+                                    event_data = self._create_event_data(event_message)
+                                    await sse_stream_writer.send(event_data)
+                except Exception as e:
+                    logger.exception(f"Error in replay sender: {e}")
+
+            # Create and start EventSourceResponse
+            response = EventSourceResponse(
+                content=sse_stream_reader,
+                data_sender_callable=replay_sender,
+                headers=headers,
+            )
+
+            try:
+                await response(request.scope, request.receive, send)
+            except Exception as e:
+                logger.exception(f"Error in replay response: {e}")
+
+        except Exception as e:
+            logger.exception(f"Error replaying events: {e}")
+            response = self._create_error_response(
+                f"Error replaying events: {str(e)}",
+                HTTPStatus.INTERNAL_SERVER_ERROR,
+                INTERNAL_ERROR,
+            )
+            await response(request.scope, request.receive, send)
+
     @asynccontextmanager
     async def connect(
         self,
@@ -691,10 +839,22 @@ class StreamableHTTPServerTransport:
                     target_request_id = str(message.root.id)

                 request_stream_id = target_request_id or GET_STREAM_KEY

+                # Store the event if we have an event store,
+                # regardless of whether a client is connected
+                # messages will be replayed on the re-connect
+                event_id = None
+                if self._event_store:
+                    event_id = await self._event_store.store_event(
+                        request_stream_id, message
+                    )
+                    logger.debug(f"Stored {event_id} from {request_stream_id}")
+
                 if request_stream_id in self._request_streams:
                     try:
+                        # Send both the message and the event ID
                         await self._request_streams[request_stream_id].send(
-                            message
+                            EventMessage(message, event_id)
                         )
                     except (
                         anyio.BrokenResourceError,
@@ -113,15 +113,12 @@ def create_app(is_json_response_enabled=False) -> Starlette:

         async with anyio.create_task_group() as tg:
             task_group = tg
-            print("Application started, task group initialized!")
             try:
                 yield
             finally:
-                print("Application shutting down, cleaning up resources...")
                 if task_group:
                     tg.cancel_scope.cancel()
                     task_group = None
-                print("Resources cleaned up successfully.")

     async def handle_streamable_http(scope, receive, send):
         request = Request(scope, receive)
@@ -148,14 +145,11 @@ def create_app(is_json_response_enabled=False) -> Starlette:
         read_stream, write_stream = streams

         async def run_server():
-            try:
-                await server.run(
-                    read_stream,
-                    write_stream,
-                    server.create_initialization_options(),
-                )
-            except Exception as e:
-                print(f"Server exception: {e}")
+            await server.run(
+                read_stream,
+                write_stream,
+                server.create_initialization_options(),
+            )

         if task_group is None:
             response = Response(
@@ -196,10 +190,6 @@ def run_server(port: int, is_json_response_enabled=False) -> None:
         port: Port to listen on.
         is_json_response_enabled: If True, use JSON responses instead of SSE streams.
     """
-    print(
-        f"Starting test server on port {port} with "
-        f"json_enabled={is_json_response_enabled}"
-    )

     app = create_app(is_json_response_enabled)
     # Configure server
@@ -218,16 +208,12 @@ def run_server(port: int, is_json_response_enabled=False) -> None:

     # This is important to catch exceptions and prevent test hangs
     try:
-        print("Server starting...")
         server.run()
-    except Exception as e:
-        print(f"ERROR: Server failed to run: {e}")
+    except Exception:
         import traceback

         traceback.print_exc()
-    print("Server shutdown")


 # Test fixtures - using same approach as SSE tests
 @pytest.fixture
@@ -273,8 +259,6 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]:
     # Clean up
     proc.kill()
     proc.join(timeout=2)
-    if proc.is_alive():
-        print("server process failed to terminate")


 @pytest.fixture
@@ -306,8 +290,6 @@ def json_response_server(json_server_port: int) -> Generator[None, None, None]:
     # Clean up
     proc.kill()
     proc.join(timeout=2)
-    if proc.is_alive():
-        print("server process failed to terminate")


 @pytest.fixture