Revert "refactor: reorganize message handling for better type safety and clar…" (#282)

This reverts commit 9d0f2daddb.
This commit is contained in:
Marcelo Trylesinski
2025-03-14 10:50:46 +01:00
committed by GitHub
parent ebb81d3b2b
commit 7196604468
17 changed files with 151 additions and 283 deletions

View File

@@ -1,11 +1,12 @@
from datetime import timedelta from datetime import timedelta
from typing import Any, Protocol from typing import Any, Protocol
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter from pydantic import AnyUrl, TypeAdapter
import mcp.types as types import mcp.types as types
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -58,8 +59,8 @@ class ClientSession(
): ):
def __init__( def __init__(
self, self,
read_stream: ReadStream, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: WriteStream, write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None, sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None, list_roots_callback: ListRootsFnT | None = None,

View File

@@ -6,16 +6,10 @@ from urllib.parse import urljoin, urlparse
import anyio import anyio
import httpx import httpx
from anyio.abc import TaskStatus from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse from httpx_sse import aconnect_sse
import mcp.types as types import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -37,11 +31,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new `sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`. event before disconnecting. All other HTTP operations are controlled by `timeout`.
""" """
read_stream: ReadStream read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: ReadStreamWriter read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: WriteStreamReader write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -90,11 +84,8 @@ async def sse_client(
case "message": case "message":
try: try:
message = MessageFrame(
message = types.JSONRPCMessage.model_validate_json( # noqa: E501 message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data sse.data
),
raw=sse,
) )
logger.debug( logger.debug(
f"Received server message: {message}" f"Received server message: {message}"

View File

@@ -1,7 +1,7 @@
import json import json
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator from typing import AsyncGenerator
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -10,7 +10,6 @@ from websockets.asyncio.client import connect as ws_connect
from websockets.typing import Subprotocol from websockets.typing import Subprotocol
import mcp.types as types import mcp.types as types
from mcp.types import MessageFrame
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,8 +19,8 @@ async def websocket_client(
url: str, url: str,
) -> AsyncGenerator[ ) -> AsyncGenerator[
tuple[ tuple[
MemoryObjectReceiveStream[MessageFrame[Any] | Exception], MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
MemoryObjectSendStream[MessageFrame[Any]], MemoryObjectSendStream[types.JSONRPCMessage],
], ],
None, None,
]: ]:
@@ -54,11 +53,7 @@ async def websocket_client(
async with read_stream_writer: async with read_stream_writer:
async for raw_text in ws: async for raw_text in ws:
try: try:
json_message = types.JSONRPCMessage.model_validate_json( message = types.JSONRPCMessage.model_validate_json(raw_text)
raw_text
)
# Create MessageFrame with JSON message as root
message = MessageFrame(message=json_message, raw=raw_text)
await read_stream_writer.send(message) await read_stream_writer.send(message)
except ValidationError as exc: except ValidationError as exc:
# If JSON parse or model validation fails, send the exception # If JSON parse or model validation fails, send the exception
@@ -71,8 +66,8 @@ async def websocket_client(
""" """
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for message in write_stream_reader:
# Extract the JSON-RPC message from MessageFrame and convert to JSON # Convert to a dict, then to JSON
msg_dict = message.message.model_dump( msg_dict = message.model_dump(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
) )
await ws.send(json.dumps(msg_dict)) await ws.send(json.dumps(msg_dict))

View File

@@ -74,6 +74,7 @@ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontext
from typing import Any, AsyncIterator, Generic, TypeVar from typing import Any, AsyncIterator, Generic, TypeVar
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
import mcp.types as types import mcp.types as types
@@ -83,7 +84,7 @@ from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.shared.session import ReadStream, RequestResponder, WriteStream from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -473,8 +474,8 @@ class Server(Generic[LifespanResultT]):
async def run( async def run(
self, self,
read_stream: ReadStream, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: WriteStream, write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
initialization_options: InitializationOptions, initialization_options: InitializationOptions,
# When False, exceptions are returned as messages to the client. # When False, exceptions are returned as messages to the client.
# When True, exceptions are raised, which will cause the server to shut down # When True, exceptions are raised, which will cause the server to shut down

View File

@@ -5,7 +5,9 @@ and tools.
from pydantic import BaseModel from pydantic import BaseModel
from mcp.types import ServerCapabilities from mcp.types import (
ServerCapabilities,
)
class InitializationOptions(BaseModel): class InitializationOptions(BaseModel):

View File

@@ -42,15 +42,14 @@ from typing import Any, TypeVar
import anyio import anyio
import anyio.lowlevel import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
import mcp.types as types import mcp.types as types
from mcp.server.models import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.shared.session import ( from mcp.shared.session import (
BaseSession, BaseSession,
ReadStream,
RequestResponder, RequestResponder,
WriteStream,
) )
@@ -77,8 +76,8 @@ class ServerSession(
def __init__( def __init__(
self, self,
read_stream: ReadStream, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: WriteStream, write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
init_options: InitializationOptions, init_options: InitializationOptions,
) -> None: ) -> None:
super().__init__( super().__init__(

View File

@@ -38,6 +38,7 @@ from urllib.parse import quote
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError from pydantic import ValidationError
from sse_starlette import EventSourceResponse from sse_starlette import EventSourceResponse
from starlette.requests import Request from starlette.requests import Request
@@ -45,13 +46,6 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
import mcp.types as types import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -69,7 +63,9 @@ class SseServerTransport:
""" """
_endpoint: str _endpoint: str
_read_stream_writers: dict[UUID, ReadStreamWriter] _read_stream_writers: dict[
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
]
def __init__(self, endpoint: str) -> None: def __init__(self, endpoint: str) -> None:
""" """
@@ -89,11 +85,11 @@ class SseServerTransport:
raise ValueError("connect_sse can only handle HTTP requests") raise ValueError("connect_sse can only handle HTTP requests")
logger.debug("Setting up SSE connection") logger.debug("Setting up SSE connection")
read_stream: ReadStream read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: ReadStreamWriter read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: WriteStreamReader write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -176,4 +172,4 @@ class SseServerTransport:
logger.debug(f"Sending message to writer: {message}") logger.debug(f"Sending message to writer: {message}")
response = Response("Accepted", status_code=202) response = Response("Accepted", status_code=202)
await response(scope, receive, send) await response(scope, receive, send)
await writer.send(MessageFrame(message=message, raw=request)) await writer.send(message)

View File

@@ -24,15 +24,9 @@ from io import TextIOWrapper
import anyio import anyio
import anyio.lowlevel import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
@asynccontextmanager @asynccontextmanager
@@ -53,11 +47,11 @@ async def stdio_server(
if not stdout: if not stdout:
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
read_stream: ReadStream read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: ReadStreamWriter read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: WriteStreamReader write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -72,9 +66,7 @@ async def stdio_server(
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue
await read_stream_writer.send( await read_stream_writer.send(message)
MessageFrame(message=message, raw=line)
)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()
@@ -82,7 +74,6 @@ async def stdio_server(
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for message in write_stream_reader:
# Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame
json = message.model_dump_json(by_alias=True, exclude_none=True) json = message.model_dump_json(by_alias=True, exclude_none=True)
await stdout.write(json + "\n") await stdout.write(json + "\n")
await stdout.flush() await stdout.flush()

View File

@@ -2,17 +2,11 @@ import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
import mcp.types as types import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
websocket = WebSocket(scope, receive, send) websocket = WebSocket(scope, receive, send)
await websocket.accept(subprotocol="mcp") await websocket.accept(subprotocol="mcp")
read_stream: ReadStream read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: ReadStreamWriter read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: WriteStreamReader write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -46,9 +40,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue
await read_stream_writer.send( await read_stream_writer.send(client_message)
MessageFrame(message=client_message, raw=message)
)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await websocket.close() await websocket.close()

View File

@@ -11,11 +11,11 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
from mcp.server import Server from mcp.server import Server
from mcp.types import MessageFrame from mcp.types import JSONRPCMessage
MessageStream = tuple[ MessageStream = tuple[
MemoryObjectReceiveStream[MessageFrame | Exception], MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[MessageFrame], MemoryObjectSendStream[JSONRPCMessage],
] ]
@@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> (
""" """
# Create streams for both directions # Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
MessageFrame | Exception JSONRPCMessage | Exception
](1) ](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
MessageFrame | Exception JSONRPCMessage | Exception
](1) ](1)
client_streams = (server_to_client_receive, client_to_server_send) client_streams = (server_to_client_receive, client_to_server_send)
@@ -60,9 +60,12 @@ async def create_connected_server_and_client_session(
) -> AsyncGenerator[ClientSession, None]: ) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server.""" """Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as ( async with create_client_server_memory_streams() as (
(client_read, client_write), client_streams,
(server_read, server_write), server_streams,
): ):
client_read, client_write = client_streams
server_read, server_write = server_streams
# Create a cancel scope for the server task # Create a cancel scope for the server task
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
tg.start_soon( tg.start_soon(

View File

@@ -22,18 +22,12 @@ from mcp.types import (
JSONRPCNotification, JSONRPCNotification,
JSONRPCRequest, JSONRPCRequest,
JSONRPCResponse, JSONRPCResponse,
MessageFrame,
RequestParams, RequestParams,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
ServerResult, ServerResult,
) )
ReadStream = MemoryObjectReceiveStream[MessageFrame | Exception]
ReadStreamWriter = MemoryObjectSendStream[MessageFrame | Exception]
WriteStream = MemoryObjectSendStream[MessageFrame]
WriteStreamReader = MemoryObjectReceiveStream[MessageFrame]
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
@@ -171,8 +165,8 @@ class BaseSession(
def __init__( def __init__(
self, self,
read_stream: ReadStream, read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: WriteStream, write_stream: MemoryObjectSendStream[JSONRPCMessage],
receive_request_type: type[ReceiveRequestT], receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT], receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out # If none, reading will never time out
@@ -248,9 +242,7 @@ class BaseSession(
# TODO: Support progress callbacks # TODO: Support progress callbacks
await self._write_stream.send( await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
MessageFrame(message=JSONRPCMessage(jsonrpc_request), raw=None)
)
try: try:
with anyio.fail_after( with anyio.fail_after(
@@ -286,18 +278,14 @@ class BaseSession(
**notification.model_dump(by_alias=True, mode="json", exclude_none=True), **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
) )
await self._write_stream.send( await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
MessageFrame(message=JSONRPCMessage(jsonrpc_notification), raw=None)
)
async def _send_response( async def _send_response(
self, request_id: RequestId, response: SendResultT | ErrorData self, request_id: RequestId, response: SendResultT | ErrorData
) -> None: ) -> None:
if isinstance(response, ErrorData): if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
await self._write_stream.send( await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
MessageFrame(message=JSONRPCMessage(jsonrpc_error), raw=None)
)
else: else:
jsonrpc_response = JSONRPCResponse( jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
@@ -306,9 +294,7 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
), ),
) )
await self._write_stream.send( await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
MessageFrame(message=JSONRPCMessage(jsonrpc_response), raw=None)
)
async def _receive_loop(self) -> None: async def _receive_loop(self) -> None:
async with ( async with (
@@ -316,13 +302,10 @@ class BaseSession(
self._write_stream, self._write_stream,
self._incoming_message_stream_writer, self._incoming_message_stream_writer,
): ):
async for raw_message in self._read_stream: async for message in self._read_stream:
if isinstance(raw_message, Exception): if isinstance(message, Exception):
await self._incoming_message_stream_writer.send(raw_message) await self._incoming_message_stream_writer.send(message)
continue elif isinstance(message.root, JSONRPCRequest):
message = raw_message.message
if isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate( validated_request = self._receive_request_type.model_validate(
message.root.model_dump( message.root.model_dump(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True

View File

@@ -180,49 +180,6 @@ class JSONRPCMessage(
pass pass
RawT = TypeVar("RawT")
class MessageFrame(BaseModel, Generic[RawT]):
"""
A wrapper around the general message received that contains both the parsed message
and the raw message.
This class serves as an encapsulation for JSON-RPC messages, providing access to
both the parsed structure (root) and the original raw data. This design is
particularly useful for Server-Sent Events (SSE) consumers who may need to access
additional metadata or headers associated with the message.
The 'root' attribute contains the parsed JSONRPCMessage, which could be a request,
notification, response, or error. The 'raw' attribute preserves the original
message as received, allowing access to any additional context or metadata that
might be lost in parsing.
This dual representation allows for flexible handling of messages, where consumers
can work with the structured data for standard operations, but still have the
option to examine or utilize the raw data when needed, such as for debugging,
logging, or accessing transport-specific information.
"""
message: JSONRPCMessage
raw: RawT | None = None
model_config = ConfigDict(extra="allow")
def model_dump(self, *args, **kwargs):
"""
Dumps the model to a dictionary, delegating to the root JSONRPCMessage.
This method allows for consistent serialization of the parsed message.
"""
return self.message.model_dump(*args, **kwargs)
def model_dump_json(self, *args, **kwargs):
"""
Dumps the model to a JSON string, delegating to the root JSONRPCMessage.
This method provides a convenient way to serialize the parsed message to JSON.
"""
return self.message.model_dump_json(*args, **kwargs)
class EmptyResult(Result): class EmptyResult(Result):
"""A response that indicates success but carries no data.""" """A response that indicates success but carries no data."""

View File

@@ -1,5 +1,3 @@
from types import NoneType
import anyio import anyio
import pytest import pytest
@@ -13,9 +11,9 @@ from mcp.types import (
InitializeRequest, InitializeRequest,
InitializeResult, InitializeResult,
JSONRPCMessage, JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest, JSONRPCRequest,
JSONRPCResponse, JSONRPCResponse,
MessageFrame,
ServerCapabilities, ServerCapabilities,
ServerResult, ServerResult,
) )
@@ -24,10 +22,10 @@ from mcp.types import (
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_initialize(): async def test_client_session_initialize():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
MessageFrame[NoneType] JSONRPCMessage
](1) ](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
MessageFrame[NoneType] JSONRPCMessage
](1) ](1)
initialized_notification = None initialized_notification = None
@@ -36,7 +34,7 @@ async def test_client_session_initialize():
nonlocal initialized_notification nonlocal initialized_notification
jsonrpc_request = await client_to_server_receive.receive() jsonrpc_request = await client_to_server_receive.receive()
assert isinstance(jsonrpc_request, MessageFrame) assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate( request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
@@ -58,25 +56,21 @@ async def test_client_session_initialize():
) )
async with server_to_client_send: async with server_to_client_send:
assert isinstance(jsonrpc_request.message.root, JSONRPCRequest)
await server_to_client_send.send( await server_to_client_send.send(
MessageFrame( JSONRPCMessage(
message=JSONRPCMessage(
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.message.root.id, id=jsonrpc_request.root.id,
result=result.model_dump( result=result.model_dump(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
), ),
) )
),
raw=None,
) )
) )
jsonrpc_notification = await client_to_server_receive.receive() jsonrpc_notification = await client_to_server_receive.receive()
assert isinstance(jsonrpc_notification.message, JSONRPCMessage) assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
initialized_notification = ClientNotification.model_validate( initialized_notification = ClientNotification.model_validate(
jsonrpc_notification.message.model_dump( jsonrpc_notification.model_dump(
by_alias=True, mode="json", exclude_none=True by_alias=True, mode="json", exclude_none=True
) )
) )

View File

@@ -11,7 +11,6 @@ from mcp.types import (
JSONRPCMessage, JSONRPCMessage,
JSONRPCNotification, JSONRPCNotification,
JSONRPCRequest, JSONRPCRequest,
MessageFrame,
NotificationParams, NotificationParams,
) )
@@ -65,9 +64,7 @@ async def test_request_id_match() -> None:
jsonrpc="2.0", jsonrpc="2.0",
) )
await client_writer.send( await client_writer.send(JSONRPCMessage(root=init_req))
MessageFrame(message=JSONRPCMessage(root=init_req), raw=None)
)
await server_reader.receive() # Get init response but don't need to check it await server_reader.receive() # Get init response but don't need to check it
# Send initialized notification # Send initialized notification
@@ -76,27 +73,21 @@ async def test_request_id_match() -> None:
params=NotificationParams().model_dump(by_alias=True, exclude_none=True), params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
jsonrpc="2.0", jsonrpc="2.0",
) )
await client_writer.send( await client_writer.send(JSONRPCMessage(root=initialized_notification))
MessageFrame(
message=JSONRPCMessage(root=initialized_notification), raw=None
)
)
# Send ping request with custom ID # Send ping request with custom ID
ping_request = JSONRPCRequest( ping_request = JSONRPCRequest(
id=custom_request_id, method="ping", params={}, jsonrpc="2.0" id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
) )
await client_writer.send( await client_writer.send(JSONRPCMessage(root=ping_request))
MessageFrame(message=JSONRPCMessage(root=ping_request), raw=None)
)
# Read response # Read response
response = await server_reader.receive() response = await server_reader.receive()
# Verify response ID matches request ID # Verify response ID matches request ID
assert ( assert (
response.message.root.id == custom_request_id response.root.id == custom_request_id
), "Response ID should match request ID" ), "Response ID should match request ID"
# Cancel server task # Cancel server task

View File

@@ -17,7 +17,6 @@ from mcp.types import (
JSONRPCMessage, JSONRPCMessage,
JSONRPCNotification, JSONRPCNotification,
JSONRPCRequest, JSONRPCRequest,
MessageFrame,
) )
@@ -65,7 +64,7 @@ async def test_lowlevel_server_lifespan():
send_stream2, send_stream2,
InitializationOptions( InitializationOptions(
server_name="test", server_name="test",
server_version="1.0.0", server_version="0.1.0",
capabilities=server.get_capabilities( capabilities=server.get_capabilities(
notification_options=NotificationOptions(), notification_options=NotificationOptions(),
experimental_capabilities={}, experimental_capabilities={},
@@ -83,51 +82,42 @@ async def test_lowlevel_server_lifespan():
clientInfo=Implementation(name="test-client", version="0.1.0"), clientInfo=Implementation(name="test-client", version="0.1.0"),
) )
await send_stream1.send( await send_stream1.send(
MessageFrame( JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest( root=JSONRPCRequest(
jsonrpc="2.0", jsonrpc="2.0",
id=1, id=1,
method="initialize", method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params), params=TypeAdapter(InitializeRequestParams).dump_python(params),
) )
),
raw=None,
) )
) )
response = await receive_stream2.receive() response = await receive_stream2.receive()
# Send initialized notification # Send initialized notification
await send_stream1.send( await send_stream1.send(
MessageFrame( JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCNotification( root=JSONRPCNotification(
jsonrpc="2.0", jsonrpc="2.0",
method="notifications/initialized", method="notifications/initialized",
) )
),
raw=None,
) )
) )
# Call the tool to verify lifespan context # Call the tool to verify lifespan context
await send_stream1.send( await send_stream1.send(
MessageFrame( JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest( root=JSONRPCRequest(
jsonrpc="2.0", jsonrpc="2.0",
id=2, id=2,
method="tools/call", method="tools/call",
params={"name": "check_lifespan", "arguments": {}}, params={"name": "check_lifespan", "arguments": {}},
) )
),
raw=None,
) )
) )
# Get response and verify # Get response and verify
response = await receive_stream2.receive() response = await receive_stream2.receive()
assert response.message.root.result["content"][0]["text"] == "true" assert response.root.result["content"][0]["text"] == "true"
# Cancel server task # Cancel server task
tg.cancel_scope.cancel() tg.cancel_scope.cancel()
@@ -188,51 +178,42 @@ async def test_fastmcp_server_lifespan():
clientInfo=Implementation(name="test-client", version="0.1.0"), clientInfo=Implementation(name="test-client", version="0.1.0"),
) )
await send_stream1.send( await send_stream1.send(
MessageFrame( JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest( root=JSONRPCRequest(
jsonrpc="2.0", jsonrpc="2.0",
id=1, id=1,
method="initialize", method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params), params=TypeAdapter(InitializeRequestParams).dump_python(params),
) )
),
raw=None,
) )
) )
response = await receive_stream2.receive() response = await receive_stream2.receive()
# Send initialized notification # Send initialized notification
await send_stream1.send( await send_stream1.send(
MessageFrame( JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCNotification( root=JSONRPCNotification(
jsonrpc="2.0", jsonrpc="2.0",
method="notifications/initialized", method="notifications/initialized",
) )
),
raw=None,
) )
) )
# Call the tool to verify lifespan context # Call the tool to verify lifespan context
await send_stream1.send( await send_stream1.send(
MessageFrame( JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest( root=JSONRPCRequest(
jsonrpc="2.0", jsonrpc="2.0",
id=2, id=2,
method="tools/call", method="tools/call",
params={"name": "check_lifespan", "arguments": {}}, params={"name": "check_lifespan", "arguments": {}},
) )
),
raw=None,
) )
) )
# Get response and verify # Get response and verify
response = await receive_stream2.receive() response = await receive_stream2.receive()
assert response.message.root.result["content"][0]["text"] == "true" assert response.root.result["content"][0]["text"] == "true"
# Cancel server task # Cancel server task
tg.cancel_scope.cancel() tg.cancel_scope.cancel()

View File

@@ -9,7 +9,7 @@ from mcp.server.session import ServerSession
from mcp.types import ( from mcp.types import (
ClientNotification, ClientNotification,
InitializedNotification, InitializedNotification,
MessageFrame, JSONRPCMessage,
PromptsCapability, PromptsCapability,
ResourcesCapability, ResourcesCapability,
ServerCapabilities, ServerCapabilities,
@@ -19,10 +19,10 @@ from mcp.types import (
@pytest.mark.anyio @pytest.mark.anyio
async def test_server_session_initialize(): async def test_server_session_initialize():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
MessageFrame[None] JSONRPCMessage
](1) ](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
MessageFrame[None] JSONRPCMessage
](1) ](1)
async def run_client(client: ClientSession): async def run_client(client: ClientSession):

View File

@@ -4,7 +4,7 @@ import anyio
import pytest import pytest
from mcp.server.stdio import stdio_server from mcp.server.stdio import stdio_server
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, MessageFrame from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
@pytest.mark.anyio @pytest.mark.anyio
@@ -13,8 +13,8 @@ async def test_stdio_server():
stdout = io.StringIO() stdout = io.StringIO()
messages = [ messages = [
JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")),
JSONRPCResponse(jsonrpc="2.0", id=2, result={}), JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})),
] ]
for message in messages: for message in messages:
@@ -35,29 +35,17 @@ async def test_stdio_server():
# Verify received messages # Verify received messages
assert len(received_messages) == 2 assert len(received_messages) == 2
assert isinstance(received_messages[0].message, JSONRPCMessage) assert received_messages[0] == JSONRPCMessage(
assert isinstance(received_messages[0].message.root, JSONRPCRequest) root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
assert received_messages[0].message.root.id == 1 )
assert received_messages[0].message.root.method == "ping" assert received_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
assert isinstance(received_messages[1].message, JSONRPCMessage) )
assert isinstance(received_messages[1].message.root, JSONRPCResponse)
assert received_messages[1].message.root.id == 2
# Test sending responses from the server # Test sending responses from the server
responses = [ responses = [
MessageFrame( JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")),
message=JSONRPCMessage( JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})),
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
),
raw=None,
),
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
),
raw=None,
),
] ]
async with write_stream: async with write_stream:
@@ -68,10 +56,13 @@ async def test_stdio_server():
output_lines = stdout.readlines() output_lines = stdout.readlines()
assert len(output_lines) == 2 assert len(output_lines) == 2
# Parse and verify the JSON responses directly received_responses = [
request_json = JSONRPCRequest.model_validate_json(output_lines[0].strip()) JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines
response_json = JSONRPCResponse.model_validate_json(output_lines[1].strip()) ]
assert len(received_responses) == 2
assert request_json.id == 3 assert received_responses[0] == JSONRPCMessage(
assert request_json.method == "ping" root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
assert response_json.id == 4 )
assert received_responses[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
)