Wrap JSONRPC messages with SessionMessage for metadata support (#590)

This commit is contained in:
ihrpr
2025-05-02 14:29:00 +01:00
committed by GitHub
parent 3978c6e1b9
commit da0cf22355
22 changed files with 286 additions and 173 deletions

View File

@@ -11,8 +11,8 @@ import mcp.types as types
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
from mcp.types import JSONRPCMessage
if not sys.warnoptions: if not sys.warnoptions:
import warnings import warnings
@@ -36,8 +36,8 @@ async def message_handler(
async def run_session( async def run_session(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[SessionMessage],
client_info: types.Implementation | None = None, client_info: types.Implementation | None = None,
): ):
async with ClientSession( async with ClientSession(

View File

@@ -7,6 +7,7 @@ from pydantic import AnyUrl, TypeAdapter
import mcp.types as types import mcp.types as types
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -92,8 +93,8 @@ class ClientSession(
): ):
def __init__( def __init__(
self, self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage], write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None, sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None, list_roots_callback: ListRootsFnT | None = None,

View File

@@ -10,6 +10,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from httpx_sse import aconnect_sse from httpx_sse import aconnect_sse
import mcp.types as types import mcp.types as types
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,11 +32,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new `sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`. event before disconnecting. All other HTTP operations are controlled by `timeout`.
""" """
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage] write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -97,7 +98,8 @@ async def sse_client(
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue
await read_stream_writer.send(message) session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _: case _:
logger.warning( logger.warning(
f"Unknown SSE event: {sse.event}" f"Unknown SSE event: {sse.event}"
@@ -111,11 +113,13 @@ async def sse_client(
async def post_writer(endpoint_url: str): async def post_writer(endpoint_url: str):
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for session_message in write_stream_reader:
logger.debug(f"Sending client message: {message}") logger.debug(
f"Sending client message: {session_message}"
)
response = await client.post( response = await client.post(
endpoint_url, endpoint_url,
json=message.model_dump( json=session_message.message.model_dump(
by_alias=True, by_alias=True,
mode="json", mode="json",
exclude_none=True, exclude_none=True,

View File

@@ -11,6 +11,7 @@ from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import mcp.types as types import mcp.types as types
from mcp.shared.message import SessionMessage
from .win32 import ( from .win32 import (
create_windows_process, create_windows_process,
@@ -98,11 +99,11 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
Client transport for stdio: this will connect to a server by spawning a Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout. process and communicating with it over stdin/stdout.
""" """
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage] write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -143,7 +144,8 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue
await read_stream_writer.send(message) session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()
@@ -152,8 +154,10 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for session_message in write_stream_reader:
json = message.model_dump_json(by_alias=True, exclude_none=True) json = session_message.message.model_dump_json(
by_alias=True, exclude_none=True
)
await process.stdin.send( await process.stdin.send(
(json + "\n").encode( (json + "\n").encode(
encoding=server.encoding, encoding=server.encoding,

View File

@@ -15,6 +15,7 @@ import anyio
import httpx import httpx
from httpx_sse import EventSource, aconnect_sse from httpx_sse import EventSource, aconnect_sse
from mcp.shared.message import SessionMessage
from mcp.types import ( from mcp.types import (
ErrorData, ErrorData,
JSONRPCError, JSONRPCError,
@@ -52,10 +53,10 @@ async def streamablehttp_client(
""" """
read_stream_writer, read_stream = anyio.create_memory_object_stream[ read_stream_writer, read_stream = anyio.create_memory_object_stream[
JSONRPCMessage | Exception SessionMessage | Exception
](0) ](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[ write_stream, write_stream_reader = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](0) ](0)
async def get_stream(): async def get_stream():
@@ -86,7 +87,8 @@ async def streamablehttp_client(
try: try:
message = JSONRPCMessage.model_validate_json(sse.data) message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"GET message: {message}") logger.debug(f"GET message: {message}")
await read_stream_writer.send(message) session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc: except Exception as exc:
logger.error(f"Error parsing GET message: {exc}") logger.error(f"Error parsing GET message: {exc}")
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
@@ -100,7 +102,8 @@ async def streamablehttp_client(
nonlocal session_id nonlocal session_id
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for session_message in write_stream_reader:
message = session_message.message
# Add session ID to headers if we have one # Add session ID to headers if we have one
post_headers = request_headers.copy() post_headers = request_headers.copy()
if session_id: if session_id:
@@ -141,9 +144,10 @@ async def streamablehttp_client(
message="Session terminated", message="Session terminated",
), ),
) )
await read_stream_writer.send( session_message = SessionMessage(
JSONRPCMessage(jsonrpc_error) JSONRPCMessage(jsonrpc_error)
) )
await read_stream_writer.send(session_message)
continue continue
response.raise_for_status() response.raise_for_status()
@@ -163,7 +167,8 @@ async def streamablehttp_client(
json_message = JSONRPCMessage.model_validate_json( json_message = JSONRPCMessage.model_validate_json(
content content
) )
await read_stream_writer.send(json_message) session_message = SessionMessage(json_message)
await read_stream_writer.send(session_message)
except Exception as exc: except Exception as exc:
logger.error(f"Error parsing JSON response: {exc}") logger.error(f"Error parsing JSON response: {exc}")
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
@@ -175,11 +180,15 @@ async def streamablehttp_client(
async for sse in event_source.aiter_sse(): async for sse in event_source.aiter_sse():
if sse.event == "message": if sse.event == "message":
try: try:
await read_stream_writer.send( message = (
JSONRPCMessage.model_validate_json( JSONRPCMessage.model_validate_json(
sse.data sse.data
) )
) )
session_message = SessionMessage(message)
await read_stream_writer.send(
session_message
)
except Exception as exc: except Exception as exc:
logger.exception("Error parsing message") logger.exception("Error parsing message")
await read_stream_writer.send(exc) await read_stream_writer.send(exc)

View File

@@ -10,6 +10,7 @@ from websockets.asyncio.client import connect as ws_connect
from websockets.typing import Subprotocol from websockets.typing import Subprotocol
import mcp.types as types import mcp.types as types
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -19,8 +20,8 @@ async def websocket_client(
url: str, url: str,
) -> AsyncGenerator[ ) -> AsyncGenerator[
tuple[ tuple[
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[types.JSONRPCMessage], MemoryObjectSendStream[SessionMessage],
], ],
None, None,
]: ]:
@@ -39,10 +40,10 @@ async def websocket_client(
# Create two in-memory streams: # Create two in-memory streams:
# - One for incoming messages (read_stream, written by ws_reader) # - One for incoming messages (read_stream, written by ws_reader)
# - One for outgoing messages (write_stream, read by ws_writer) # - One for outgoing messages (write_stream, read by ws_writer)
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage] write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -59,7 +60,8 @@ async def websocket_client(
async for raw_text in ws: async for raw_text in ws:
try: try:
message = types.JSONRPCMessage.model_validate_json(raw_text) message = types.JSONRPCMessage.model_validate_json(raw_text)
await read_stream_writer.send(message) session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except ValidationError as exc: except ValidationError as exc:
# If JSON parse or model validation fails, send the exception # If JSON parse or model validation fails, send the exception
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
@@ -70,9 +72,9 @@ async def websocket_client(
sends them to the server. sends them to the server.
""" """
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for session_message in write_stream_reader:
# Convert to a dict, then to JSON # Convert to a dict, then to JSON
msg_dict = message.model_dump( msg_dict = session_message.message.model_dump(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
) )
await ws.send(json.dumps(msg_dict)) await ws.send(json.dumps(msg_dict))

View File

@@ -84,6 +84,7 @@ from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -471,8 +472,8 @@ class Server(Generic[LifespanResultT]):
async def run( async def run(
self, self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage], write_stream: MemoryObjectSendStream[SessionMessage],
initialization_options: InitializationOptions, initialization_options: InitializationOptions,
# When False, exceptions are returned as messages to the client. # When False, exceptions are returned as messages to the client.
# When True, exceptions are raised, which will cause the server to shut down # When True, exceptions are raised, which will cause the server to shut down

View File

@@ -47,6 +47,7 @@ from pydantic import AnyUrl
import mcp.types as types import mcp.types as types
from mcp.server.models import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.shared.message import SessionMessage
from mcp.shared.session import ( from mcp.shared.session import (
BaseSession, BaseSession,
RequestResponder, RequestResponder,
@@ -82,8 +83,8 @@ class ServerSession(
def __init__( def __init__(
self, self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage], write_stream: MemoryObjectSendStream[SessionMessage],
init_options: InitializationOptions, init_options: InitializationOptions,
stateless: bool = False, stateless: bool = False,
) -> None: ) -> None:

View File

@@ -46,6 +46,7 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
import mcp.types as types import mcp.types as types
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -63,9 +64,7 @@ class SseServerTransport:
""" """
_endpoint: str _endpoint: str
_read_stream_writers: dict[ _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
]
def __init__(self, endpoint: str) -> None: def __init__(self, endpoint: str) -> None:
""" """
@@ -85,11 +84,11 @@ class SseServerTransport:
raise ValueError("connect_sse can only handle HTTP requests") raise ValueError("connect_sse can only handle HTTP requests")
logger.debug("Setting up SSE connection") logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage] write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -109,12 +108,12 @@ class SseServerTransport:
await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
logger.debug(f"Sent endpoint event: {session_uri}") logger.debug(f"Sent endpoint event: {session_uri}")
async for message in write_stream_reader: async for session_message in write_stream_reader:
logger.debug(f"Sending message via SSE: {message}") logger.debug(f"Sending message via SSE: {session_message}")
await sse_stream_writer.send( await sse_stream_writer.send(
{ {
"event": "message", "event": "message",
"data": message.model_dump_json( "data": session_message.message.model_dump_json(
by_alias=True, exclude_none=True by_alias=True, exclude_none=True
), ),
} }
@@ -169,7 +168,8 @@ class SseServerTransport:
await writer.send(err) await writer.send(err)
return return
logger.debug(f"Sending message to writer: {message}") session_message = SessionMessage(message)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202) response = Response("Accepted", status_code=202)
await response(scope, receive, send) await response(scope, receive, send)
await writer.send(message) await writer.send(session_message)

View File

@@ -27,6 +27,7 @@ import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types import mcp.types as types
from mcp.shared.message import SessionMessage
@asynccontextmanager @asynccontextmanager
@@ -47,11 +48,11 @@ async def stdio_server(
if not stdout: if not stdout:
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage] write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -66,15 +67,18 @@ async def stdio_server(
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue
await read_stream_writer.send(message) session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()
async def stdout_writer(): async def stdout_writer():
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for session_message in write_stream_reader:
json = message.model_dump_json(by_alias=True, exclude_none=True) json = session_message.message.model_dump_json(
by_alias=True, exclude_none=True
)
await stdout.write(json + "\n") await stdout.write(json + "\n")
await stdout.flush() await stdout.flush()
except anyio.ClosedResourceError: except anyio.ClosedResourceError:

View File

@@ -24,6 +24,7 @@ from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
from mcp.shared.message import SessionMessage
from mcp.types import ( from mcp.types import (
INTERNAL_ERROR, INTERNAL_ERROR,
INVALID_PARAMS, INVALID_PARAMS,
@@ -125,10 +126,10 @@ class StreamableHTTPServerTransport:
""" """
# Server notification streams for POST requests as well as standalone SSE stream # Server notification streams for POST requests as well as standalone SSE stream
_read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = ( _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
None None
) )
_write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
def __init__( def __init__(
self, self,
@@ -378,7 +379,8 @@ class StreamableHTTPServerTransport:
await response(scope, receive, send) await response(scope, receive, send)
# Process the message after sending the response # Process the message after sending the response
await writer.send(message) session_message = SessionMessage(message)
await writer.send(session_message)
return return
@@ -394,7 +396,8 @@ class StreamableHTTPServerTransport:
if self.is_json_response_enabled: if self.is_json_response_enabled:
# Process the message # Process the message
await writer.send(message) session_message = SessionMessage(message)
await writer.send(session_message)
try: try:
# Process messages from the request-specific stream # Process messages from the request-specific stream
# We need to collect all messages until we get a response # We need to collect all messages until we get a response
@@ -500,7 +503,8 @@ class StreamableHTTPServerTransport:
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
tg.start_soon(response, scope, receive, send) tg.start_soon(response, scope, receive, send)
# Then send the message to be processed by the server # Then send the message to be processed by the server
await writer.send(message) session_message = SessionMessage(message)
await writer.send(session_message)
except Exception: except Exception:
logger.exception("SSE response error") logger.exception("SSE response error")
# Clean up the request stream if something goes wrong # Clean up the request stream if something goes wrong
@@ -792,8 +796,8 @@ class StreamableHTTPServerTransport:
self, self,
) -> AsyncGenerator[ ) -> AsyncGenerator[
tuple[ tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage], MemoryObjectSendStream[SessionMessage],
], ],
None, None,
]: ]:
@@ -806,10 +810,10 @@ class StreamableHTTPServerTransport:
# Create the memory streams for this connection # Create the memory streams for this connection
read_stream_writer, read_stream = anyio.create_memory_object_stream[ read_stream_writer, read_stream = anyio.create_memory_object_stream[
JSONRPCMessage | Exception SessionMessage | Exception
](0) ](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[ write_stream, write_stream_reader = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](0) ](0)
# Store the streams # Store the streams
@@ -821,8 +825,9 @@ class StreamableHTTPServerTransport:
# Create a message router that distributes messages to request streams # Create a message router that distributes messages to request streams
async def message_router(): async def message_router():
try: try:
async for message in write_stream_reader: async for session_message in write_stream_reader:
# Determine which request stream(s) should receive this message # Determine which request stream(s) should receive this message
message = session_message.message
target_request_id = None target_request_id = None
if isinstance( if isinstance(
message.root, JSONRPCNotification | JSONRPCRequest message.root, JSONRPCNotification | JSONRPCRequest

View File

@@ -8,6 +8,7 @@ from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
import mcp.types as types import mcp.types as types
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,11 +23,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
websocket = WebSocket(scope, receive, send) websocket = WebSocket(scope, receive, send)
await websocket.accept(subprotocol="mcp") await websocket.accept(subprotocol="mcp")
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage] write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -41,15 +42,18 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue
await read_stream_writer.send(client_message) session_message = SessionMessage(client_message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await websocket.close() await websocket.close()
async def ws_writer(): async def ws_writer():
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for session_message in write_stream_reader:
obj = message.model_dump_json(by_alias=True, exclude_none=True) obj = session_message.message.model_dump_json(
by_alias=True, exclude_none=True
)
await websocket.send_text(obj) await websocket.send_text(obj)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await websocket.close() await websocket.close()

View File

@@ -19,11 +19,11 @@ from mcp.client.session import (
SamplingFnT, SamplingFnT,
) )
from mcp.server import Server from mcp.server import Server
from mcp.types import JSONRPCMessage from mcp.shared.message import SessionMessage
MessageStream = tuple[ MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage], MemoryObjectSendStream[SessionMessage],
] ]
@@ -40,10 +40,10 @@ async def create_client_server_memory_streams() -> (
""" """
# Create streams for both directions # Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception SessionMessage | Exception
](1) ](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception SessionMessage | Exception
](1) ](1)
client_streams = (server_to_client_receive, client_to_server_send) client_streams = (server_to_client_receive, client_to_server_send)

35
src/mcp/shared/message.py Normal file
View File

@@ -0,0 +1,35 @@
"""
Message wrapper with metadata support.
This module defines a wrapper type that combines JSONRPCMessage with metadata
to support transport-specific features like resumability.
"""
from dataclasses import dataclass
from mcp.types import JSONRPCMessage, RequestId
@dataclass
class ClientMessageMetadata:
"""Metadata specific to client messages."""
resumption_token: str | None = None
@dataclass
class ServerMessageMetadata:
"""Metadata specific to server messages."""
related_request_id: RequestId | None = None
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
@dataclass
class SessionMessage:
"""A message with specific metadata for transport-specific features."""
message: JSONRPCMessage
metadata: MessageMetadata = None

View File

@@ -12,6 +12,7 @@ from pydantic import BaseModel
from typing_extensions import Self from typing_extensions import Self
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.types import ( from mcp.types import (
CancelledNotification, CancelledNotification,
ClientNotification, ClientNotification,
@@ -172,8 +173,8 @@ class BaseSession(
def __init__( def __init__(
self, self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[SessionMessage],
receive_request_type: type[ReceiveRequestT], receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT], receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out # If none, reading will never time out
@@ -240,7 +241,9 @@ class BaseSession(
# TODO: Support progress callbacks # TODO: Support progress callbacks
await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) await self._write_stream.send(
SessionMessage(message=JSONRPCMessage(jsonrpc_request))
)
# request read timeout takes precedence over session read timeout # request read timeout takes precedence over session read timeout
timeout = None timeout = None
@@ -300,14 +303,16 @@ class BaseSession(
jsonrpc="2.0", jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True), **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
) )
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification))
await self._write_stream.send(session_message)
async def _send_response( async def _send_response(
self, request_id: RequestId, response: SendResultT | ErrorData self, request_id: RequestId, response: SendResultT | ErrorData
) -> None: ) -> None:
if isinstance(response, ErrorData): if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
await self._write_stream.send(JSONRPCMessage(jsonrpc_error)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
await self._write_stream.send(session_message)
else: else:
jsonrpc_response = JSONRPCResponse( jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
@@ -316,7 +321,8 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
), ),
) )
await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
await self._write_stream.send(session_message)
async def _receive_loop(self) -> None: async def _receive_loop(self) -> None:
async with ( async with (
@@ -326,15 +332,15 @@ class BaseSession(
async for message in self._read_stream: async for message in self._read_stream:
if isinstance(message, Exception): if isinstance(message, Exception):
await self._handle_incoming(message) await self._handle_incoming(message)
elif isinstance(message.root, JSONRPCRequest): elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate( validated_request = self._receive_request_type.model_validate(
message.root.model_dump( message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
) )
) )
responder = RequestResponder( responder = RequestResponder(
request_id=message.root.id, request_id=message.message.root.id,
request_meta=validated_request.root.params.meta request_meta=validated_request.root.params.meta
if validated_request.root.params if validated_request.root.params
else None, else None,
@@ -349,10 +355,10 @@ class BaseSession(
if not responder._completed: # type: ignore[reportPrivateUsage] if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder) await self._handle_incoming(responder)
elif isinstance(message.root, JSONRPCNotification): elif isinstance(message.message.root, JSONRPCNotification):
try: try:
notification = self._receive_notification_type.model_validate( notification = self._receive_notification_type.model_validate(
message.root.model_dump( message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
) )
) )
@@ -368,12 +374,12 @@ class BaseSession(
# For other validation errors, log and continue # For other validation errors, log and continue
logging.warning( logging.warning(
f"Failed to validate notification: {e}. " f"Failed to validate notification: {e}. "
f"Message was: {message.root}" f"Message was: {message.message.root}"
) )
else: # Response or error else: # Response or error
stream = self._response_streams.pop(message.root.id, None) stream = self._response_streams.pop(message.message.root.id, None)
if stream: if stream:
await stream.send(message.root) await stream.send(message.message.root)
else: else:
await self._handle_incoming( await self._handle_incoming(
RuntimeError( RuntimeError(

View File

@@ -3,6 +3,7 @@ import pytest
import mcp.types as types import mcp.types as types
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
from mcp.types import ( from mcp.types import (
LATEST_PROTOCOL_VERSION, LATEST_PROTOCOL_VERSION,
@@ -24,10 +25,10 @@ from mcp.types import (
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_initialize(): async def test_client_session_initialize():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](1) ](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](1) ](1)
initialized_notification = None initialized_notification = None
@@ -35,7 +36,8 @@ async def test_client_session_initialize():
async def mock_server(): async def mock_server():
nonlocal initialized_notification nonlocal initialized_notification
jsonrpc_request = await client_to_server_receive.receive() session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest) assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate( request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
@@ -59,6 +61,7 @@ async def test_client_session_initialize():
async with server_to_client_send: async with server_to_client_send:
await server_to_client_send.send( await server_to_client_send.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
@@ -69,7 +72,9 @@ async def test_client_session_initialize():
) )
) )
) )
jsonrpc_notification = await client_to_server_receive.receive() )
session_notification = await client_to_server_receive.receive()
jsonrpc_notification = session_notification.message
assert isinstance(jsonrpc_notification.root, JSONRPCNotification) assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
initialized_notification = ClientNotification.model_validate( initialized_notification = ClientNotification.model_validate(
jsonrpc_notification.model_dump( jsonrpc_notification.model_dump(
@@ -116,10 +121,10 @@ async def test_client_session_initialize():
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_custom_client_info(): async def test_client_session_custom_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](1) ](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](1) ](1)
custom_client_info = Implementation(name="test-client", version="1.2.3") custom_client_info = Implementation(name="test-client", version="1.2.3")
@@ -128,7 +133,8 @@ async def test_client_session_custom_client_info():
async def mock_server(): async def mock_server():
nonlocal received_client_info nonlocal received_client_info
jsonrpc_request = await client_to_server_receive.receive() session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest) assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate( request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
@@ -146,6 +152,7 @@ async def test_client_session_custom_client_info():
async with server_to_client_send: async with server_to_client_send:
await server_to_client_send.send( await server_to_client_send.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
@@ -156,6 +163,7 @@ async def test_client_session_custom_client_info():
) )
) )
) )
)
# Receive initialized notification # Receive initialized notification
await client_to_server_receive.receive() await client_to_server_receive.receive()
@@ -181,10 +189,10 @@ async def test_client_session_custom_client_info():
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_default_client_info(): async def test_client_session_default_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](1) ](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](1) ](1)
received_client_info = None received_client_info = None
@@ -192,7 +200,8 @@ async def test_client_session_default_client_info():
async def mock_server(): async def mock_server():
nonlocal received_client_info nonlocal received_client_info
jsonrpc_request = await client_to_server_receive.receive() session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest) assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate( request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
@@ -210,6 +219,7 @@ async def test_client_session_default_client_info():
async with server_to_client_send: async with server_to_client_send:
await server_to_client_send.send( await server_to_client_send.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
@@ -220,6 +230,7 @@ async def test_client_session_default_client_info():
) )
) )
) )
)
# Receive initialized notification # Receive initialized notification
await client_to_server_receive.receive() await client_to_server_receive.receive()

View File

@@ -3,6 +3,7 @@ import shutil
import pytest import pytest
from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
tee: str = shutil.which("tee") # type: ignore tee: str = shutil.which("tee") # type: ignore
@@ -22,7 +23,8 @@ async def test_stdio_client():
async with write_stream: async with write_stream:
for message in messages: for message in messages:
await write_stream.send(message) session_message = SessionMessage(message)
await write_stream.send(session_message)
read_messages = [] read_messages = []
async with read_stream: async with read_stream:
@@ -30,7 +32,7 @@ async def test_stdio_client():
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message
read_messages.append(message) read_messages.append(message.message)
if len(read_messages) == 2: if len(read_messages) == 2:
break break

View File

@@ -3,6 +3,7 @@ import pytest
from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.lowlevel import NotificationOptions, Server
from mcp.server.models import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.shared.message import SessionMessage
from mcp.types import ( from mcp.types import (
LATEST_PROTOCOL_VERSION, LATEST_PROTOCOL_VERSION,
ClientCapabilities, ClientCapabilities,
@@ -64,8 +65,10 @@ async def test_request_id_match() -> None:
jsonrpc="2.0", jsonrpc="2.0",
) )
await client_writer.send(JSONRPCMessage(root=init_req)) await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req)))
await server_reader.receive() # Get init response but don't need to check it response = (
await server_reader.receive()
) # Get init response but don't need to check it
# Send initialized notification # Send initialized notification
initialized_notification = JSONRPCNotification( initialized_notification = JSONRPCNotification(
@@ -73,21 +76,23 @@ async def test_request_id_match() -> None:
params=NotificationParams().model_dump(by_alias=True, exclude_none=True), params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
jsonrpc="2.0", jsonrpc="2.0",
) )
await client_writer.send(JSONRPCMessage(root=initialized_notification)) await client_writer.send(
SessionMessage(JSONRPCMessage(root=initialized_notification))
)
# Send ping request with custom ID # Send ping request with custom ID
ping_request = JSONRPCRequest( ping_request = JSONRPCRequest(
id=custom_request_id, method="ping", params={}, jsonrpc="2.0" id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
) )
await client_writer.send(JSONRPCMessage(root=ping_request)) await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request)))
# Read response # Read response
response = await server_reader.receive() response = await server_reader.receive()
# Verify response ID matches request ID # Verify response ID matches request ID
assert ( assert (
response.root.id == custom_request_id response.message.root.id == custom_request_id
), "Response ID should match request ID" ), "Response ID should match request ID"
# Cancel server task # Cancel server task

View File

@@ -10,6 +10,7 @@ from pydantic import TypeAdapter
from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp import Context, FastMCP
from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.lowlevel.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.shared.message import SessionMessage
from mcp.types import ( from mcp.types import (
ClientCapabilities, ClientCapabilities,
Implementation, Implementation,
@@ -82,6 +83,7 @@ async def test_lowlevel_server_lifespan():
clientInfo=Implementation(name="test-client", version="0.1.0"), clientInfo=Implementation(name="test-client", version="0.1.0"),
) )
await send_stream1.send( await send_stream1.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
root=JSONRPCRequest( root=JSONRPCRequest(
jsonrpc="2.0", jsonrpc="2.0",
@@ -91,10 +93,13 @@ async def test_lowlevel_server_lifespan():
) )
) )
) )
)
response = await receive_stream2.receive() response = await receive_stream2.receive()
response = response.message
# Send initialized notification # Send initialized notification
await send_stream1.send( await send_stream1.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
root=JSONRPCNotification( root=JSONRPCNotification(
jsonrpc="2.0", jsonrpc="2.0",
@@ -102,9 +107,11 @@ async def test_lowlevel_server_lifespan():
) )
) )
) )
)
# Call the tool to verify lifespan context # Call the tool to verify lifespan context
await send_stream1.send( await send_stream1.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
root=JSONRPCRequest( root=JSONRPCRequest(
jsonrpc="2.0", jsonrpc="2.0",
@@ -114,9 +121,11 @@ async def test_lowlevel_server_lifespan():
) )
) )
) )
)
# Get response and verify # Get response and verify
response = await receive_stream2.receive() response = await receive_stream2.receive()
response = response.message
assert response.root.result["content"][0]["text"] == "true" assert response.root.result["content"][0]["text"] == "true"
# Cancel server task # Cancel server task
@@ -178,6 +187,7 @@ async def test_fastmcp_server_lifespan():
clientInfo=Implementation(name="test-client", version="0.1.0"), clientInfo=Implementation(name="test-client", version="0.1.0"),
) )
await send_stream1.send( await send_stream1.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
root=JSONRPCRequest( root=JSONRPCRequest(
jsonrpc="2.0", jsonrpc="2.0",
@@ -187,10 +197,13 @@ async def test_fastmcp_server_lifespan():
) )
) )
) )
)
response = await receive_stream2.receive() response = await receive_stream2.receive()
response = response.message
# Send initialized notification # Send initialized notification
await send_stream1.send( await send_stream1.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
root=JSONRPCNotification( root=JSONRPCNotification(
jsonrpc="2.0", jsonrpc="2.0",
@@ -198,9 +211,11 @@ async def test_fastmcp_server_lifespan():
) )
) )
) )
)
# Call the tool to verify lifespan context # Call the tool to verify lifespan context
await send_stream1.send( await send_stream1.send(
SessionMessage(
JSONRPCMessage( JSONRPCMessage(
root=JSONRPCRequest( root=JSONRPCRequest(
jsonrpc="2.0", jsonrpc="2.0",
@@ -210,9 +225,11 @@ async def test_fastmcp_server_lifespan():
) )
) )
) )
)
# Get response and verify # Get response and verify
response = await receive_stream2.receive() response = await receive_stream2.receive()
response = response.message
assert response.root.result["content"][0]["text"] == "true" assert response.root.result["content"][0]["text"] == "true"
# Cancel server task # Cancel server task

View File

@@ -8,10 +8,10 @@ from mcp.server import Server
from mcp.server.lowlevel import NotificationOptions from mcp.server.lowlevel import NotificationOptions
from mcp.server.models import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
from mcp.types import ( from mcp.types import (
ClientResult, ClientResult,
JSONRPCMessage,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
Tool, Tool,
@@ -46,10 +46,10 @@ async def test_lowlevel_server_tool_annotations():
] ]
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](10) ](10)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](10) ](10)
# Message handler for client # Message handler for client

View File

@@ -7,11 +7,11 @@ from mcp.server import Server
from mcp.server.lowlevel import NotificationOptions from mcp.server.lowlevel import NotificationOptions
from mcp.server.models import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
from mcp.types import ( from mcp.types import (
ClientNotification, ClientNotification,
InitializedNotification, InitializedNotification,
JSONRPCMessage,
PromptsCapability, PromptsCapability,
ResourcesCapability, ResourcesCapability,
ServerCapabilities, ServerCapabilities,
@@ -21,10 +21,10 @@ from mcp.types import (
@pytest.mark.anyio @pytest.mark.anyio
async def test_server_session_initialize(): async def test_server_session_initialize():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](1) ](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage SessionMessage
](1) ](1)
# Create a message handler to catch exceptions # Create a message handler to catch exceptions

View File

@@ -4,6 +4,7 @@ import anyio
import pytest import pytest
from mcp.server.stdio import stdio_server from mcp.server.stdio import stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
@@ -29,7 +30,7 @@ async def test_stdio_server():
async for message in read_stream: async for message in read_stream:
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message
received_messages.append(message) received_messages.append(message.message)
if len(received_messages) == 2: if len(received_messages) == 2:
break break
@@ -50,7 +51,8 @@ async def test_stdio_server():
async with write_stream: async with write_stream:
for response in responses: for response in responses:
await write_stream.send(response) session_message = SessionMessage(response)
await write_stream.send(session_message)
stdout.seek(0) stdout.seek(0)
output_lines = stdout.readlines() output_lines = stdout.readlines()