Add message queue for SSE messages POST endpoint (#459)

This commit is contained in:
Akash D
2025-05-06 17:10:43 -07:00
committed by GitHub
parent 58c5e7223c
commit 3b1b213a96
26 changed files with 1247 additions and 50 deletions

View File

@@ -44,6 +44,7 @@ from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel.server import LifespanResultT
from mcp.server.lowlevel.server import Server as MCPServer
from mcp.server.lowlevel.server import lifespan as default_lifespan
from mcp.server.message_queue import MessageDispatch
from mcp.server.session import ServerSession, ServerSessionT
from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
@@ -90,6 +91,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
sse_path: str = "/sse"
message_path: str = "/messages/"
# SSE message queue settings
message_dispatch: MessageDispatch | None = Field(
None, description="Custom message dispatch instance"
)
# resource settings
warn_on_duplicate_resources: bool = True
@@ -569,12 +575,21 @@ class FastMCP:
def sse_app(self) -> Starlette:
"""Return an instance of the SSE server app."""
message_dispatch = self.settings.message_dispatch
if message_dispatch is None:
from mcp.server.message_queue import InMemoryMessageDispatch
message_dispatch = InMemoryMessageDispatch()
logger.info("Using default in-memory message dispatch")
from starlette.middleware import Middleware
from starlette.routing import Mount, Route
# Set up auth context and dependencies
sse = SseServerTransport(self.settings.message_path)
sse = SseServerTransport(
self.settings.message_path, message_dispatch=message_dispatch
)
async def handle_sse(scope: Scope, receive: Receive, send: Send):
# Add client ID from auth context into request context if available
@@ -589,7 +604,14 @@ class FastMCP:
streams[1],
self._mcp_server.create_initialization_options(),
)
return Response()
return Response()
@asynccontextmanager
async def lifespan(app: Starlette):
try:
yield
finally:
await message_dispatch.close()
# Create routes
routes: list[Route | Mount] = []
@@ -666,7 +688,10 @@ class FastMCP:
# Create Starlette app with routes and middleware
return Starlette(
debug=self.settings.debug, routes=routes, middleware=middleware
debug=self.settings.debug,
routes=routes,
middleware=middleware,
lifespan=lifespan,
)
async def list_prompts(self) -> list[MCPPrompt]:

View File

@@ -0,0 +1,16 @@
"""
Message Dispatch Module for MCP Server
This module implements dispatch interfaces for handling
messages between clients and servers.
"""
from mcp.server.message_queue.base import InMemoryMessageDispatch, MessageDispatch
# Try to import Redis implementation if available
try:
from mcp.server.message_queue.redis import RedisMessageDispatch
except ImportError:
RedisMessageDispatch = None
__all__ = ["MessageDispatch", "InMemoryMessageDispatch", "RedisMessageDispatch"]

View File

@@ -0,0 +1,116 @@
import logging
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Protocol, runtime_checkable
from uuid import UUID
from pydantic import ValidationError
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
MessageCallback = Callable[[SessionMessage | Exception], Awaitable[None]]
@runtime_checkable
class MessageDispatch(Protocol):
    """Structural interface for routing SSE messages to sessions.

    An implementation accepts messages addressed to a session and delivers
    them to the callback registered for that session, which allows more than
    one server to take part in handling requests.
    """

    async def publish_message(
        self, session_id: UUID, message: SessionMessage | str
    ) -> bool:
        """Publish a message addressed to one session.

        Args:
            session_id: UUID of the session the message is intended for.
            message: Either a ready ``SessionMessage`` or a raw string; the
                string form is used for payloads that were invalid JSON.

        Returns:
            bool: True when the message was published, False when the
            session could not be found.
        """
        ...

    @asynccontextmanager
    async def subscribe(self, session_id: UUID, callback: MessageCallback):
        """Subscribe to a session's messages for the duration of a request.

        Args:
            session_id: UUID of the session to listen to.
            callback: Async callable that handles each message for the
                session.
        """
        yield

    async def session_exists(self, session_id: UUID) -> bool:
        """Report whether a session is currently active.

        Args:
            session_id: UUID of the session to look up.

        Returns:
            bool: True if the session is active, False otherwise.
        """
        ...

    async def close(self) -> None:
        """Close the message dispatch."""
        ...
class InMemoryMessageDispatch:
    """Default in-memory implementation of the MessageDispatch interface.

    A published message is handed directly to the callback registered for the
    target session; nothing is queued or stored.
    """

    def __init__(self) -> None:
        # Active subscriptions: session ID -> callback receiving its messages.
        self._callbacks: dict[UUID, MessageCallback] = {}

    async def publish_message(
        self, session_id: UUID, message: SessionMessage | str
    ) -> bool:
        """Deliver a message to the callback registered for ``session_id``.

        Args:
            session_id: UUID of the session the message is intended for.
            message: A ``SessionMessage``, or a raw JSON string that is
                validated here.

        Returns:
            bool: True if the callback was invoked, False if no callback is
            registered for the session.
        """
        # Single lookup instead of a membership test followed by indexing.
        callback = self._callbacks.get(session_id)
        if callback is None:
            logger.warning("Message dropped: unknown session %s", session_id)
            return False

        callback_argument: SessionMessage | Exception
        if isinstance(message, str):
            # Raw strings are parsed here; a validation failure is forwarded
            # to the callback as the exception object rather than raised.
            try:
                callback_argument = SessionMessage.model_validate_json(message)
            except ValidationError as exc:
                callback_argument = exc
        else:
            callback_argument = message

        await callback(callback_argument)
        logger.debug("Message dispatched to session %s", session_id)
        return True

    @asynccontextmanager
    async def subscribe(self, session_id: UUID, callback: MessageCallback):
        """Register ``callback`` for ``session_id`` while the context is open.

        Args:
            session_id: UUID of the session to listen to.
            callback: Async callable invoked with each message (or validation
                error) published for the session.
        """
        self._callbacks[session_id] = callback
        logger.debug("Subscribing to messages for session %s", session_id)
        try:
            yield
        finally:
            # pop() removes the entry in one step and tolerates it already
            # being gone; only log when something was actually removed.
            if self._callbacks.pop(session_id, None) is not None:
                logger.debug("Unsubscribed from session %s", session_id)

    async def session_exists(self, session_id: UUID) -> bool:
        """Return True if a callback is currently registered for the session."""
        return session_id in self._callbacks

    async def close(self) -> None:
        """Close the message dispatch (nothing to release for this backend)."""
        pass

View File

@@ -0,0 +1,198 @@
import logging
from contextlib import asynccontextmanager
from typing import Any, cast
from uuid import UUID
import anyio
from anyio import CancelScope, CapacityLimiter, lowlevel
from anyio.abc import TaskGroup
from pydantic import ValidationError
from mcp.server.message_queue.base import MessageCallback
from mcp.shared.message import SessionMessage
try:
import redis.asyncio as redis
except ImportError:
raise ImportError(
"Redis support requires the 'redis' package. "
"Install it with: 'uv add redis' or 'uv add \"mcp[redis]\"'"
)
logger = logging.getLogger(__name__)
class RedisMessageDispatch:
    """Redis implementation of the MessageDispatch interface using pubsub.

    Messages are published to a per-session Redis channel and delivered to the
    callback registered for that session, so several servers can handle the
    same sessions. A per-session key with a TTL marks the session as active.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        prefix: str = "mcp:pubsub:",
        session_ttl: int = 3600,  # 1 hour default TTL for sessions
    ) -> None:
        """Initialize Redis message dispatch.

        Args:
            redis_url: Redis connection string
            prefix: Key prefix for Redis channels to avoid collisions
            session_ttl: TTL in seconds for session keys (default: 1 hour)

        Raises:
            ValueError: If ``session_ttl`` is not positive.
        """
        # Fail fast: the TTL is passed to SETEX/EXPIRE and halved for the
        # heartbeat sleep, neither of which is meaningful for values <= 0.
        if session_ttl <= 0:
            raise ValueError("session_ttl must be a positive number of seconds")

        self._redis = redis.from_url(redis_url, decode_responses=True)  # type: ignore
        self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True)  # type: ignore
        self._prefix = prefix
        self._session_ttl = session_ttl
        # Maps session IDs to the callback and task group for that SSE session.
        self._session_state: dict[UUID, tuple[MessageCallback, TaskGroup]] = {}
        # Ensures only one polling task runs at a time for message handling
        self._limiter = CapacityLimiter(1)
        # NOTE(review): redis_url may embed credentials; consider redacting.
        logger.debug("Redis message dispatch initialized: %s", redis_url)

    async def close(self) -> None:
        """Close the pubsub object and then the Redis client."""
        await self._pubsub.aclose()  # type: ignore
        await self._redis.aclose()  # type: ignore

    def _session_channel(self, session_id: UUID) -> str:
        """Get the Redis channel for a session."""
        return f"{self._prefix}session:{session_id.hex}"

    def _session_key(self, session_id: UUID) -> str:
        """Get the Redis key for a session."""
        return f"{self._prefix}session_active:{session_id.hex}"

    @asynccontextmanager
    async def subscribe(self, session_id: UUID, callback: MessageCallback):
        """Request-scoped context manager that subscribes to messages for a session.

        Marks the session active in Redis, subscribes to its channel and runs
        the listener and heartbeat tasks until the context exits.

        Args:
            session_id: UUID of the session to listen to.
            callback: Async callable invoked with each message (or validation
                error) received for the session.
        """
        session_key = self._session_key(session_id)
        await self._redis.setex(session_key, self._session_ttl, "1")  # type: ignore

        channel = self._session_channel(session_id)
        await self._pubsub.subscribe(channel)  # type: ignore

        logger.debug("Subscribing to Redis channel for session %s", session_id)
        async with anyio.create_task_group() as tg:
            self._session_state[session_id] = (callback, tg)
            tg.start_soon(self._listen_for_messages)
            # Start heartbeat for this session
            tg.start_soon(self._session_heartbeat, session_id)
            try:
                yield
            finally:
                # Shielded so cleanup still runs when the request is cancelled.
                with anyio.CancelScope(shield=True):
                    tg.cancel_scope.cancel()
                    # Drop local state first: it is synchronous and cannot
                    # fail, so it is never leaked by a Redis error below.
                    self._session_state.pop(session_id, None)
                    try:
                        await self._pubsub.unsubscribe(channel)  # type: ignore
                    finally:
                        # Remove the "active" marker even if unsubscribe failed.
                        await self._redis.delete(session_key)  # type: ignore
                    logger.debug("Unsubscribed from Redis channel: %s", session_id)

    async def _session_heartbeat(self, session_id: UUID) -> None:
        """Periodically refresh the TTL for a session.

        Runs until cancelled. Cancellation is deliberately not caught, so it
        propagates to the owning task group.
        """
        session_key = self._session_key(session_id)
        while True:
            await lowlevel.checkpoint()
            try:
                # Refresh TTL at half the TTL interval to avoid expiration
                await anyio.sleep(self._session_ttl / 2)
                # Shield the refresh so it is not interrupted mid-command.
                with anyio.CancelScope(shield=True):
                    await self._redis.expire(session_key, self._session_ttl)  # type: ignore
            except Exception as e:
                # Cancellation exceptions derive from BaseException and are
                # therefore not swallowed here; only real errors are logged.
                logger.error("Error refreshing TTL for session %s: %s", session_id, e)

    def _extract_session_id(self, channel: str) -> UUID | None:
        """Extract and validate session ID from channel.

        Returns:
            The session UUID, or None if the channel lacks the expected
            prefix, holds an invalid UUID, or is not in canonical form.
        """
        expected_prefix = f"{self._prefix}session:"
        if not channel.startswith(expected_prefix):
            return None

        session_hex = channel[len(expected_prefix) :]
        try:
            session_id = UUID(hex=session_hex)
        except ValueError:
            logger.error("Invalid UUID in channel: %s", channel)
            return None

        # UUID() accepts several spellings; only the canonical one is valid.
        if channel != self._session_channel(session_id):
            logger.error("Channel format mismatch: %s", channel)
            return None
        return session_id

    async def _listen_for_messages(self) -> None:
        """Background task that listens for messages on subscribed channels."""
        # The limiter lets a single task poll the shared pubsub object.
        async with self._limiter:
            while True:
                # Cancellation is only observed here, between polls.
                await lowlevel.checkpoint()
                with CancelScope(shield=True):
                    message: None | dict[str, Any] = await self._pubsub.get_message(  # type: ignore
                        ignore_subscribe_messages=True,
                        timeout=0.1,  # type: ignore
                    )
                    if message is None:
                        continue

                    channel: str = cast(str, message["channel"])
                    session_id = self._extract_session_id(channel)
                    if session_id is None:
                        logger.debug(
                            "Ignoring message from non-MCP channel: %s", channel
                        )
                        continue

                    data: str = cast(str, message["data"])
                    try:
                        if session_state := self._session_state.get(session_id):
                            # Handle the message in the owning session's task
                            # group so polling is not blocked by the callback.
                            session_state[1].start_soon(
                                self._handle_message, session_id, data
                            )
                        else:
                            logger.warning(
                                "Message dropped: unknown session %s", session_id
                            )
                    except Exception as e:
                        logger.error(
                            "Error processing message for %s: %s", session_id, e
                        )

    async def _handle_message(self, session_id: UUID, data: str) -> None:
        """Process a message from Redis in the session's task group."""
        if (session_state := self._session_state.get(session_id)) is None:
            logger.warning("Message dropped: callback removed for %s", session_id)
            return

        try:
            # Parse message or pass validation error to callback
            msg_or_error: SessionMessage | Exception
            try:
                msg_or_error = SessionMessage.model_validate_json(data)
            except ValidationError as exc:
                msg_or_error = exc

            await session_state[0](msg_or_error)
        except Exception as e:
            logger.error("Error in message handler for %s: %s", session_id, e)

    async def publish_message(
        self, session_id: UUID, message: SessionMessage | str
    ) -> bool:
        """Publish a message for the specified session.

        Args:
            session_id: UUID of the session the message is intended for.
            message: A ``SessionMessage`` (serialized to JSON here) or a raw
                JSON string forwarded unchanged.

        Returns:
            bool: True if the message was published, False if the session's
            key does not exist in Redis.
        """
        if not await self.session_exists(session_id):
            logger.warning("Message dropped: unknown session %s", session_id)
            return False

        # Pass raw JSON strings directly, preserving validation errors
        if isinstance(message, str):
            data = message
        else:
            data = message.model_dump_json()

        channel = self._session_channel(session_id)
        await self._redis.publish(channel, data)  # type: ignore[attr-defined]
        logger.debug("Message published to Redis channel for session %s", session_id)
        return True

    async def session_exists(self, session_id: UUID) -> bool:
        """Check if a session exists."""
        session_key = self._session_key(session_id)
        return bool(await self._redis.exists(session_key))  # type: ignore

View File

@@ -52,9 +52,11 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send
import mcp.types as types
from mcp.server.message_queue import InMemoryMessageDispatch, MessageDispatch
from mcp.shared.message import SessionMessage
# Module-level logger. Library code must not call logging.basicConfig():
# doing so at import time would install a root handler and force DEBUG
# output on every application that imports this module. Logging
# configuration is left to the host application.
logger = logging.getLogger(__name__)
class SseServerTransport:
@@ -70,17 +72,24 @@ class SseServerTransport:
"""
_endpoint: str
_message_dispatch: MessageDispatch
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
def __init__(self, endpoint: str) -> None:
def __init__(
self, endpoint: str, message_dispatch: MessageDispatch | None = None
) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.
Args:
endpoint: The endpoint URL for SSE connections
message_dispatch: Optional message dispatch to use
"""
super().__init__()
self._endpoint = endpoint
self._read_stream_writers = {}
self._message_dispatch = message_dispatch or InMemoryMessageDispatch()
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
@asynccontextmanager
@@ -101,7 +110,12 @@ class SseServerTransport:
session_id = uuid4()
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
self._read_stream_writers[session_id] = read_stream_writer
async def message_callback(message: SessionMessage | Exception) -> None:
"""Callback that receives messages from the message queue"""
logger.debug(f"Got message from queue for session {session_id}")
await read_stream_writer.send(message)
logger.debug(f"Created new session with ID: {session_id}")
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
@@ -138,13 +152,16 @@ class SseServerTransport:
)(scope, receive, send)
await read_stream_writer.aclose()
await write_stream_reader.aclose()
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
logging.debug(f"Client session disconnected {session_id}")
logger.debug("Starting SSE response task")
tg.start_soon(response_wrapper, scope, receive, send)
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
async with self._message_dispatch.subscribe(session_id, message_callback):
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
async def handle_post_message(
self, scope: Scope, receive: Receive, send: Send
@@ -166,8 +183,7 @@ class SseServerTransport:
response = Response("Invalid session ID", status_code=400)
return await response(scope, receive, send)
writer = self._read_stream_writers.get(session_id)
if not writer:
if not await self._message_dispatch.session_exists(session_id):
logger.warning(f"Could not find session for ID: {session_id}")
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)
@@ -182,11 +198,15 @@ class SseServerTransport:
logger.error(f"Failed to parse message: {err}")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
await writer.send(err)
# Pass raw JSON string; receiver will recreate identical ValidationError
# when parsing the same invalid JSON
await self._message_dispatch.publish_message(session_id, body.decode())
return
session_message = SessionMessage(message)
logger.debug(f"Sending session message to writer: {session_message}")
logger.debug(f"Publishing message for session {session_id}: {message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(session_message)
await self._message_dispatch.publish_message(
session_id, SessionMessage(message=message)
)
logger.debug(f"Sending session message to writer: {message}")

View File

@@ -67,7 +67,7 @@ async def stdio_server(
await read_stream_writer.send(exc)
continue
session_message = SessionMessage(message)
session_message = SessionMessage(message=message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

View File

@@ -398,7 +398,7 @@ class StreamableHTTPServerTransport:
await response(scope, receive, send)
# Process the message after sending the response
session_message = SessionMessage(message)
session_message = SessionMessage(message=message)
await writer.send(session_message)
return
@@ -413,7 +413,7 @@ class StreamableHTTPServerTransport:
if self.is_json_response_enabled:
# Process the message
session_message = SessionMessage(message)
session_message = SessionMessage(message=message)
await writer.send(session_message)
try:
# Process messages from the request-specific stream
@@ -512,7 +512,7 @@ class StreamableHTTPServerTransport:
async with anyio.create_task_group() as tg:
tg.start_soon(response, scope, receive, send)
# Then send the message to be processed by the server
session_message = SessionMessage(message)
session_message = SessionMessage(message=message)
await writer.send(session_message)
except Exception:
logger.exception("SSE response error")

View File

@@ -42,7 +42,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
await read_stream_writer.send(exc)
continue
session_message = SessionMessage(client_message)
session_message = SessionMessage(message=client_message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError:
await websocket.close()