mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Add message queue for SSE messages POST endpoint (#459)
This commit is contained in:
@@ -98,7 +98,9 @@ async def sse_client(
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(
|
||||
message=message
|
||||
)
|
||||
await read_stream_writer.send(session_message)
|
||||
case _:
|
||||
logger.warning(
|
||||
@@ -148,3 +150,5 @@ async def sse_client(
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
await read_stream.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
|
||||
@@ -144,7 +144,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
@@ -153,7 +153,7 @@ class StreamableHTTPTransport:
|
||||
):
|
||||
message.root.id = original_request_id
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await read_stream_writer.send(session_message)
|
||||
|
||||
# Call resumption token callback if we have an ID
|
||||
@@ -286,7 +286,7 @@ class StreamableHTTPTransport:
|
||||
try:
|
||||
content = await response.aread()
|
||||
message = JSONRPCMessage.model_validate_json(content)
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error parsing JSON response: {exc}")
|
||||
@@ -333,7 +333,7 @@ class StreamableHTTPTransport:
|
||||
id=request_id,
|
||||
error=ErrorData(code=32600, message="Session terminated"),
|
||||
)
|
||||
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
await read_stream_writer.send(session_message)
|
||||
|
||||
async def post_writer(
|
||||
|
||||
@@ -60,7 +60,7 @@ async def websocket_client(
|
||||
async for raw_text in ws:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(raw_text)
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except ValidationError as exc:
|
||||
# If JSON parse or model validation fails, send the exception
|
||||
|
||||
@@ -44,6 +44,7 @@ from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||
from mcp.server.lowlevel.server import LifespanResultT
|
||||
from mcp.server.lowlevel.server import Server as MCPServer
|
||||
from mcp.server.lowlevel.server import lifespan as default_lifespan
|
||||
from mcp.server.message_queue import MessageDispatch
|
||||
from mcp.server.session import ServerSession, ServerSessionT
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.stdio import stdio_server
|
||||
@@ -90,6 +91,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
sse_path: str = "/sse"
|
||||
message_path: str = "/messages/"
|
||||
|
||||
# SSE message queue settings
|
||||
message_dispatch: MessageDispatch | None = Field(
|
||||
None, description="Custom message dispatch instance"
|
||||
)
|
||||
|
||||
# resource settings
|
||||
warn_on_duplicate_resources: bool = True
|
||||
|
||||
@@ -569,12 +575,21 @@ class FastMCP:
|
||||
|
||||
def sse_app(self) -> Starlette:
|
||||
"""Return an instance of the SSE server app."""
|
||||
message_dispatch = self.settings.message_dispatch
|
||||
if message_dispatch is None:
|
||||
from mcp.server.message_queue import InMemoryMessageDispatch
|
||||
|
||||
message_dispatch = InMemoryMessageDispatch()
|
||||
logger.info("Using default in-memory message dispatch")
|
||||
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
# Set up auth context and dependencies
|
||||
|
||||
sse = SseServerTransport(self.settings.message_path)
|
||||
sse = SseServerTransport(
|
||||
self.settings.message_path, message_dispatch=message_dispatch
|
||||
)
|
||||
|
||||
async def handle_sse(scope: Scope, receive: Receive, send: Send):
|
||||
# Add client ID from auth context into request context if available
|
||||
@@ -589,7 +604,14 @@ class FastMCP:
|
||||
streams[1],
|
||||
self._mcp_server.create_initialization_options(),
|
||||
)
|
||||
return Response()
|
||||
return Response()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: Starlette):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await message_dispatch.close()
|
||||
|
||||
# Create routes
|
||||
routes: list[Route | Mount] = []
|
||||
@@ -666,7 +688,10 @@ class FastMCP:
|
||||
|
||||
# Create Starlette app with routes and middleware
|
||||
return Starlette(
|
||||
debug=self.settings.debug, routes=routes, middleware=middleware
|
||||
debug=self.settings.debug,
|
||||
routes=routes,
|
||||
middleware=middleware,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> list[MCPPrompt]:
|
||||
|
||||
16
src/mcp/server/message_queue/__init__.py
Normal file
16
src/mcp/server/message_queue/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Message Dispatch Module for MCP Server
|
||||
|
||||
This module implements dispatch interfaces for handling
|
||||
messages between clients and servers.
|
||||
"""
|
||||
|
||||
from mcp.server.message_queue.base import InMemoryMessageDispatch, MessageDispatch
|
||||
|
||||
# Try to import Redis implementation if available
|
||||
try:
|
||||
from mcp.server.message_queue.redis import RedisMessageDispatch
|
||||
except ImportError:
|
||||
RedisMessageDispatch = None
|
||||
|
||||
__all__ = ["MessageDispatch", "InMemoryMessageDispatch", "RedisMessageDispatch"]
|
||||
116
src/mcp/server/message_queue/base.py
Normal file
116
src/mcp/server/message_queue/base.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Protocol, runtime_checkable
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MessageCallback = Callable[[SessionMessage | Exception], Awaitable[None]]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessageDispatch(Protocol):
|
||||
"""Abstract interface for SSE message dispatching.
|
||||
|
||||
This interface allows messages to be published to sessions and callbacks to be
|
||||
registered for message handling, enabling multiple servers to handle requests.
|
||||
"""
|
||||
|
||||
async def publish_message(
|
||||
self, session_id: UUID, message: SessionMessage | str
|
||||
) -> bool:
|
||||
"""Publish a message for the specified session.
|
||||
|
||||
Args:
|
||||
session_id: The UUID of the session this message is for
|
||||
message: The message to publish (SessionMessage or str for invalid JSON)
|
||||
|
||||
Returns:
|
||||
bool: True if message was published, False if session not found
|
||||
"""
|
||||
...
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(self, session_id: UUID, callback: MessageCallback):
|
||||
"""Request-scoped context manager that subscribes to messages for a session.
|
||||
|
||||
Args:
|
||||
session_id: The UUID of the session to subscribe to
|
||||
callback: Async callback function to handle messages for this session
|
||||
"""
|
||||
yield
|
||||
|
||||
async def session_exists(self, session_id: UUID) -> bool:
|
||||
"""Check if a session exists.
|
||||
|
||||
Args:
|
||||
session_id: The UUID of the session to check
|
||||
|
||||
Returns:
|
||||
bool: True if the session is active, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the message dispatch."""
|
||||
...
|
||||
|
||||
|
||||
class InMemoryMessageDispatch:
|
||||
"""Default in-memory implementation of the MessageDispatch interface.
|
||||
|
||||
This implementation immediately dispatches messages to registered callbacks when
|
||||
messages are received without any queuing behavior.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._callbacks: dict[UUID, MessageCallback] = {}
|
||||
|
||||
async def publish_message(
|
||||
self, session_id: UUID, message: SessionMessage | str
|
||||
) -> bool:
|
||||
"""Publish a message for the specified session."""
|
||||
if session_id not in self._callbacks:
|
||||
logger.warning(f"Message dropped: unknown session {session_id}")
|
||||
return False
|
||||
|
||||
# Parse string messages or recreate original ValidationError
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
callback_argument = SessionMessage.model_validate_json(message)
|
||||
except ValidationError as exc:
|
||||
callback_argument = exc
|
||||
else:
|
||||
callback_argument = message
|
||||
|
||||
# Call the callback with either valid message or recreated ValidationError
|
||||
await self._callbacks[session_id](callback_argument)
|
||||
|
||||
logger.debug(f"Message dispatched to session {session_id}")
|
||||
return True
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(self, session_id: UUID, callback: MessageCallback):
|
||||
"""Request-scoped context manager that subscribes to messages for a session."""
|
||||
self._callbacks[session_id] = callback
|
||||
logger.debug(f"Subscribing to messages for session {session_id}")
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if session_id in self._callbacks:
|
||||
del self._callbacks[session_id]
|
||||
logger.debug(f"Unsubscribed from session {session_id}")
|
||||
|
||||
async def session_exists(self, session_id: UUID) -> bool:
|
||||
"""Check if a session exists."""
|
||||
return session_id in self._callbacks
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the message dispatch."""
|
||||
pass
|
||||
198
src/mcp/server/message_queue/redis.py
Normal file
198
src/mcp/server/message_queue/redis.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, CapacityLimiter, lowlevel
|
||||
from anyio.abc import TaskGroup
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mcp.server.message_queue.base import MessageCallback
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Redis support requires the 'redis' package. "
|
||||
"Install it with: 'uv add redis' or 'uv add \"mcp[redis]\"'"
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisMessageDispatch:
|
||||
"""Redis implementation of the MessageDispatch interface using pubsub.
|
||||
|
||||
This implementation uses Redis pubsub for real-time message distribution across
|
||||
multiple servers handling the same sessions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str = "redis://localhost:6379/0",
|
||||
prefix: str = "mcp:pubsub:",
|
||||
session_ttl: int = 3600, # 1 hour default TTL for sessions
|
||||
) -> None:
|
||||
"""Initialize Redis message dispatch.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection string
|
||||
prefix: Key prefix for Redis channels to avoid collisions
|
||||
session_ttl: TTL in seconds for session keys (default: 1 hour)
|
||||
"""
|
||||
self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore
|
||||
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore
|
||||
self._prefix = prefix
|
||||
self._session_ttl = session_ttl
|
||||
# Maps session IDs to the callback and task group for that SSE session.
|
||||
self._session_state: dict[UUID, tuple[MessageCallback, TaskGroup]] = {}
|
||||
# Ensures only one polling task runs at a time for message handling
|
||||
self._limiter = CapacityLimiter(1)
|
||||
logger.debug(f"Redis message dispatch initialized: {redis_url}")
|
||||
|
||||
async def close(self):
|
||||
await self._pubsub.aclose() # type: ignore
|
||||
await self._redis.aclose() # type: ignore
|
||||
|
||||
def _session_channel(self, session_id: UUID) -> str:
|
||||
"""Get the Redis channel for a session."""
|
||||
return f"{self._prefix}session:{session_id.hex}"
|
||||
|
||||
def _session_key(self, session_id: UUID) -> str:
|
||||
"""Get the Redis key for a session."""
|
||||
return f"{self._prefix}session_active:{session_id.hex}"
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(self, session_id: UUID, callback: MessageCallback):
|
||||
"""Request-scoped context manager that subscribes to messages for a session."""
|
||||
session_key = self._session_key(session_id)
|
||||
await self._redis.setex(session_key, self._session_ttl, "1") # type: ignore
|
||||
|
||||
channel = self._session_channel(session_id)
|
||||
await self._pubsub.subscribe(channel) # type: ignore
|
||||
|
||||
logger.debug(f"Subscribing to Redis channel for session {session_id}")
|
||||
async with anyio.create_task_group() as tg:
|
||||
self._session_state[session_id] = (callback, tg)
|
||||
tg.start_soon(self._listen_for_messages)
|
||||
# Start heartbeat for this session
|
||||
tg.start_soon(self._session_heartbeat, session_id)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
tg.cancel_scope.cancel()
|
||||
await self._pubsub.unsubscribe(channel) # type: ignore
|
||||
await self._redis.delete(session_key) # type: ignore
|
||||
del self._session_state[session_id]
|
||||
logger.debug(f"Unsubscribed from Redis channel: {session_id}")
|
||||
|
||||
async def _session_heartbeat(self, session_id: UUID) -> None:
|
||||
"""Periodically refresh the TTL for a session."""
|
||||
session_key = self._session_key(session_id)
|
||||
while True:
|
||||
await lowlevel.checkpoint()
|
||||
try:
|
||||
# Refresh TTL at half the TTL interval to avoid expiration
|
||||
await anyio.sleep(self._session_ttl / 2)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self._redis.expire(session_key, self._session_ttl) # type: ignore
|
||||
except anyio.get_cancelled_exc_class():
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing TTL for session {session_id}: {e}")
|
||||
|
||||
def _extract_session_id(self, channel: str) -> UUID | None:
|
||||
"""Extract and validate session ID from channel."""
|
||||
expected_prefix = f"{self._prefix}session:"
|
||||
if not channel.startswith(expected_prefix):
|
||||
return None
|
||||
|
||||
session_hex = channel[len(expected_prefix) :]
|
||||
try:
|
||||
session_id = UUID(hex=session_hex)
|
||||
if channel != self._session_channel(session_id):
|
||||
logger.error(f"Channel format mismatch: {channel}")
|
||||
return None
|
||||
return session_id
|
||||
except ValueError:
|
||||
logger.error(f"Invalid UUID in channel: {channel}")
|
||||
return None
|
||||
|
||||
async def _listen_for_messages(self) -> None:
|
||||
"""Background task that listens for messages on subscribed channels."""
|
||||
async with self._limiter:
|
||||
while True:
|
||||
await lowlevel.checkpoint()
|
||||
with CancelScope(shield=True):
|
||||
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
|
||||
ignore_subscribe_messages=True,
|
||||
timeout=0.1, # type: ignore
|
||||
)
|
||||
if message is None:
|
||||
continue
|
||||
|
||||
channel: str = cast(str, message["channel"])
|
||||
session_id = self._extract_session_id(channel)
|
||||
if session_id is None:
|
||||
logger.debug(
|
||||
f"Ignoring message from non-MCP channel: {channel}"
|
||||
)
|
||||
continue
|
||||
|
||||
data: str = cast(str, message["data"])
|
||||
try:
|
||||
if session_state := self._session_state.get(session_id):
|
||||
session_state[1].start_soon(
|
||||
self._handle_message, session_id, data
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Message dropped: unknown session {session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message for {session_id}: {e}")
|
||||
|
||||
async def _handle_message(self, session_id: UUID, data: str) -> None:
|
||||
"""Process a message from Redis in the session's task group."""
|
||||
if (session_state := self._session_state.get(session_id)) is None:
|
||||
logger.warning(f"Message dropped: callback removed for {session_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Parse message or pass validation error to callback
|
||||
msg_or_error = None
|
||||
try:
|
||||
msg_or_error = SessionMessage.model_validate_json(data)
|
||||
except ValidationError as exc:
|
||||
msg_or_error = exc
|
||||
|
||||
await session_state[0](msg_or_error)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in message handler for {session_id}: {e}")
|
||||
|
||||
async def publish_message(
|
||||
self, session_id: UUID, message: SessionMessage | str
|
||||
) -> bool:
|
||||
"""Publish a message for the specified session."""
|
||||
if not await self.session_exists(session_id):
|
||||
logger.warning(f"Message dropped: unknown session {session_id}")
|
||||
return False
|
||||
|
||||
# Pass raw JSON strings directly, preserving validation errors
|
||||
if isinstance(message, str):
|
||||
data = message
|
||||
else:
|
||||
data = message.model_dump_json()
|
||||
|
||||
channel = self._session_channel(session_id)
|
||||
await self._redis.publish(channel, data) # type: ignore[attr-defined]
|
||||
logger.debug(f"Message published to Redis channel for session {session_id}")
|
||||
return True
|
||||
|
||||
async def session_exists(self, session_id: UUID) -> bool:
|
||||
"""Check if a session exists."""
|
||||
session_key = self._session_key(session_id)
|
||||
return bool(await self._redis.exists(session_key)) # type: ignore
|
||||
@@ -52,9 +52,11 @@ from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.message_queue import InMemoryMessageDispatch, MessageDispatch
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
class SseServerTransport:
|
||||
@@ -70,17 +72,24 @@ class SseServerTransport:
|
||||
"""
|
||||
|
||||
_endpoint: str
|
||||
_message_dispatch: MessageDispatch
|
||||
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
|
||||
|
||||
def __init__(self, endpoint: str) -> None:
|
||||
def __init__(
|
||||
self, endpoint: str, message_dispatch: MessageDispatch | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Creates a new SSE server transport, which will direct the client to POST
|
||||
messages to the relative or absolute URL given.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint URL for SSE connections
|
||||
message_dispatch: Optional message dispatch to use
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._endpoint = endpoint
|
||||
self._read_stream_writers = {}
|
||||
self._message_dispatch = message_dispatch or InMemoryMessageDispatch()
|
||||
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -101,7 +110,12 @@ class SseServerTransport:
|
||||
|
||||
session_id = uuid4()
|
||||
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
|
||||
self._read_stream_writers[session_id] = read_stream_writer
|
||||
|
||||
async def message_callback(message: SessionMessage | Exception) -> None:
|
||||
"""Callback that receives messages from the message queue"""
|
||||
logger.debug(f"Got message from queue for session {session_id}")
|
||||
await read_stream_writer.send(message)
|
||||
|
||||
logger.debug(f"Created new session with ID: {session_id}")
|
||||
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
@@ -138,13 +152,16 @@ class SseServerTransport:
|
||||
)(scope, receive, send)
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
await sse_stream_writer.aclose()
|
||||
await sse_stream_reader.aclose()
|
||||
logging.debug(f"Client session disconnected {session_id}")
|
||||
|
||||
logger.debug("Starting SSE response task")
|
||||
tg.start_soon(response_wrapper, scope, receive, send)
|
||||
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
async with self._message_dispatch.subscribe(session_id, message_callback):
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
async def handle_post_message(
|
||||
self, scope: Scope, receive: Receive, send: Send
|
||||
@@ -166,8 +183,7 @@ class SseServerTransport:
|
||||
response = Response("Invalid session ID", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
writer = self._read_stream_writers.get(session_id)
|
||||
if not writer:
|
||||
if not await self._message_dispatch.session_exists(session_id):
|
||||
logger.warning(f"Could not find session for ID: {session_id}")
|
||||
response = Response("Could not find session", status_code=404)
|
||||
return await response(scope, receive, send)
|
||||
@@ -182,11 +198,15 @@ class SseServerTransport:
|
||||
logger.error(f"Failed to parse message: {err}")
|
||||
response = Response("Could not parse message", status_code=400)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(err)
|
||||
# Pass raw JSON string; receiver will recreate identical ValidationError
|
||||
# when parsing the same invalid JSON
|
||||
await self._message_dispatch.publish_message(session_id, body.decode())
|
||||
return
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
logger.debug(f"Sending session message to writer: {session_message}")
|
||||
logger.debug(f"Publishing message for session {session_id}: {message}")
|
||||
response = Response("Accepted", status_code=202)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(session_message)
|
||||
await self._message_dispatch.publish_message(
|
||||
session_id, SessionMessage(message=message)
|
||||
)
|
||||
logger.debug(f"Sending session message to writer: {message}")
|
||||
|
||||
@@ -67,7 +67,7 @@ async def stdio_server(
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
@@ -398,7 +398,7 @@ class StreamableHTTPServerTransport:
|
||||
await response(scope, receive, send)
|
||||
|
||||
# Process the message after sending the response
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await writer.send(session_message)
|
||||
|
||||
return
|
||||
@@ -413,7 +413,7 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
if self.is_json_response_enabled:
|
||||
# Process the message
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await writer.send(session_message)
|
||||
try:
|
||||
# Process messages from the request-specific stream
|
||||
@@ -512,7 +512,7 @@ class StreamableHTTPServerTransport:
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(response, scope, receive, send)
|
||||
# Then send the message to be processed by the server
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await writer.send(session_message)
|
||||
except Exception:
|
||||
logger.exception("SSE response error")
|
||||
|
||||
@@ -42,7 +42,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
session_message = SessionMessage(client_message)
|
||||
session_message = SessionMessage(message=client_message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
@@ -6,7 +6,8 @@ to support transport-specific features like resumability.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp.types import JSONRPCMessage, RequestId
|
||||
|
||||
@@ -15,8 +16,7 @@ ResumptionToken = str
|
||||
ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientMessageMetadata:
|
||||
class ClientMessageMetadata(BaseModel):
|
||||
"""Metadata specific to client messages."""
|
||||
|
||||
resumption_token: ResumptionToken | None = None
|
||||
@@ -25,8 +25,7 @@ class ClientMessageMetadata:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerMessageMetadata:
|
||||
class ServerMessageMetadata(BaseModel):
|
||||
"""Metadata specific to server messages."""
|
||||
|
||||
related_request_id: RequestId | None = None
|
||||
@@ -35,9 +34,8 @@ class ServerMessageMetadata:
|
||||
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionMessage:
|
||||
class SessionMessage(BaseModel):
|
||||
"""A message with specific metadata for transport-specific features."""
|
||||
|
||||
message: JSONRPCMessage
|
||||
metadata: MessageMetadata = None
|
||||
metadata: MessageMetadata | None = None
|
||||
|
||||
Reference in New Issue
Block a user