Revert "refactor: reorganize message handling for better type safety and clar…" (#282)

This reverts commit 9d0f2daddb.
This commit is contained in:
Marcelo Trylesinski
2025-03-14 10:50:46 +01:00
committed by GitHub
parent ebb81d3b2b
commit 7196604468
17 changed files with 151 additions and 283 deletions

View File

@@ -1,11 +1,12 @@
from datetime import timedelta
from typing import Any, Protocol
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -58,8 +59,8 @@ class ClientSession(
):
def __init__(
self,
read_stream: ReadStream,
write_stream: WriteStream,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,

View File

@@ -6,16 +6,10 @@ from urllib.parse import urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
logger = logging.getLogger(__name__)
@@ -37,11 +31,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: ReadStream
read_stream_writer: ReadStreamWriter
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream
write_stream_reader: WriteStreamReader
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -90,11 +84,8 @@ async def sse_client(
case "message":
try:
message = MessageFrame(
message=types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
),
raw=sse,
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(
f"Received server message: {message}"

View File

@@ -1,7 +1,7 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -10,7 +10,6 @@ from websockets.asyncio.client import connect as ws_connect
from websockets.typing import Subprotocol
import mcp.types as types
from mcp.types import MessageFrame
logger = logging.getLogger(__name__)
@@ -20,8 +19,8 @@ async def websocket_client(
url: str,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[MessageFrame[Any] | Exception],
MemoryObjectSendStream[MessageFrame[Any]],
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
MemoryObjectSendStream[types.JSONRPCMessage],
],
None,
]:
@@ -54,11 +53,7 @@ async def websocket_client(
async with read_stream_writer:
async for raw_text in ws:
try:
json_message = types.JSONRPCMessage.model_validate_json(
raw_text
)
# Create MessageFrame with JSON message as root
message = MessageFrame(message=json_message, raw=raw_text)
message = types.JSONRPCMessage.model_validate_json(raw_text)
await read_stream_writer.send(message)
except ValidationError as exc:
# If JSON parse or model validation fails, send the exception
@@ -71,8 +66,8 @@ async def websocket_client(
"""
async with write_stream_reader:
async for message in write_stream_reader:
# Extract the JSON-RPC message from MessageFrame and convert to JSON
msg_dict = message.message.model_dump(
# Convert to a dict, then to JSON
msg_dict = message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await ws.send(json.dumps(msg_dict))

View File

@@ -74,6 +74,7 @@ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontext
from typing import Any, AsyncIterator, Generic, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
import mcp.types as types
@@ -83,7 +84,7 @@ from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import ReadStream, RequestResponder, WriteStream
from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__)
@@ -473,8 +474,8 @@ class Server(Generic[LifespanResultT]):
async def run(
self,
read_stream: ReadStream,
write_stream: WriteStream,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
initialization_options: InitializationOptions,
# When False, exceptions are returned as messages to the client.
# When True, exceptions are raised, which will cause the server to shut down

View File

@@ -5,7 +5,9 @@ and tools.
from pydantic import BaseModel
from mcp.types import ServerCapabilities
from mcp.types import (
ServerCapabilities,
)
class InitializationOptions(BaseModel):

View File

@@ -42,15 +42,14 @@ from typing import Any, TypeVar
import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.session import (
BaseSession,
ReadStream,
RequestResponder,
WriteStream,
)
@@ -77,8 +76,8 @@ class ServerSession(
def __init__(
self,
read_stream: ReadStream,
write_stream: WriteStream,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
init_options: InitializationOptions,
) -> None:
super().__init__(

View File

@@ -38,6 +38,7 @@ from urllib.parse import quote
from uuid import UUID, uuid4
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
@@ -45,13 +46,6 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send
import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
logger = logging.getLogger(__name__)
@@ -69,7 +63,9 @@ class SseServerTransport:
"""
_endpoint: str
_read_stream_writers: dict[UUID, ReadStreamWriter]
_read_stream_writers: dict[
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
]
def __init__(self, endpoint: str) -> None:
"""
@@ -89,11 +85,11 @@ class SseServerTransport:
raise ValueError("connect_sse can only handle HTTP requests")
logger.debug("Setting up SSE connection")
read_stream: ReadStream
read_stream_writer: ReadStreamWriter
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream
write_stream_reader: WriteStreamReader
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -176,4 +172,4 @@ class SseServerTransport:
logger.debug(f"Sending message to writer: {message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(MessageFrame(message=message, raw=request))
await writer.send(message)

View File

@@ -24,15 +24,9 @@ from io import TextIOWrapper
import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
@asynccontextmanager
@@ -53,11 +47,11 @@ async def stdio_server(
if not stdout:
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
read_stream: ReadStream
read_stream_writer: ReadStreamWriter
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream
write_stream_reader: WriteStreamReader
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -72,9 +66,7 @@ async def stdio_server(
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(
MessageFrame(message=message, raw=line)
)
await read_stream_writer.send(message)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
@@ -82,7 +74,6 @@ async def stdio_server(
try:
async with write_stream_reader:
async for message in write_stream_reader:
# Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame
json = message.model_dump_json(by_alias=True, exclude_none=True)
await stdout.write(json + "\n")
await stdout.flush()

View File

@@ -2,17 +2,11 @@ import logging
from contextlib import asynccontextmanager
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket
import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
logger = logging.getLogger(__name__)
@@ -27,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
websocket = WebSocket(scope, receive, send)
await websocket.accept(subprotocol="mcp")
read_stream: ReadStream
read_stream_writer: ReadStreamWriter
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream
write_stream_reader: WriteStreamReader
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -46,9 +40,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(
MessageFrame(message=client_message, raw=message)
)
await read_stream_writer.send(client_message)
except anyio.ClosedResourceError:
await websocket.close()

View File

@@ -11,11 +11,11 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
from mcp.server import Server
from mcp.types import MessageFrame
from mcp.types import JSONRPCMessage
MessageStream = tuple[
MemoryObjectReceiveStream[MessageFrame | Exception],
MemoryObjectSendStream[MessageFrame],
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
@@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> (
"""
# Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
MessageFrame | Exception
JSONRPCMessage | Exception
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
MessageFrame | Exception
JSONRPCMessage | Exception
](1)
client_streams = (server_to_client_receive, client_to_server_send)
@@ -60,9 +60,12 @@ async def create_connected_server_and_client_session(
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
(client_read, client_write),
(server_read, server_write),
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams
# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(

View File

@@ -22,18 +22,12 @@ from mcp.types import (
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
MessageFrame,
RequestParams,
ServerNotification,
ServerRequest,
ServerResult,
)
ReadStream = MemoryObjectReceiveStream[MessageFrame | Exception]
ReadStreamWriter = MemoryObjectSendStream[MessageFrame | Exception]
WriteStream = MemoryObjectSendStream[MessageFrame]
WriteStreamReader = MemoryObjectReceiveStream[MessageFrame]
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
@@ -171,8 +165,8 @@ class BaseSession(
def __init__(
self,
read_stream: ReadStream,
write_stream: WriteStream,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
@@ -248,9 +242,7 @@ class BaseSession(
# TODO: Support progress callbacks
await self._write_stream.send(
MessageFrame(message=JSONRPCMessage(jsonrpc_request), raw=None)
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
try:
with anyio.fail_after(
@@ -286,18 +278,14 @@ class BaseSession(
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
await self._write_stream.send(
MessageFrame(message=JSONRPCMessage(jsonrpc_notification), raw=None)
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
async def _send_response(
self, request_id: RequestId, response: SendResultT | ErrorData
) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
await self._write_stream.send(
MessageFrame(message=JSONRPCMessage(jsonrpc_error), raw=None)
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
@@ -306,9 +294,7 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True
),
)
await self._write_stream.send(
MessageFrame(message=JSONRPCMessage(jsonrpc_response), raw=None)
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
async def _receive_loop(self) -> None:
async with (
@@ -316,13 +302,10 @@ class BaseSession(
self._write_stream,
self._incoming_message_stream_writer,
):
async for raw_message in self._read_stream:
if isinstance(raw_message, Exception):
await self._incoming_message_stream_writer.send(raw_message)
continue
message = raw_message.message
if isinstance(message.root, JSONRPCRequest):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._incoming_message_stream_writer.send(message)
elif isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True

View File

@@ -180,49 +180,6 @@ class JSONRPCMessage(
pass
RawT = TypeVar("RawT")
class MessageFrame(BaseModel, Generic[RawT]):
"""
A wrapper around the general message received that contains both the parsed message
and the raw message.
This class serves as an encapsulation for JSON-RPC messages, providing access to
both the parsed structure (root) and the original raw data. This design is
particularly useful for Server-Sent Events (SSE) consumers who may need to access
additional metadata or headers associated with the message.
The 'root' attribute contains the parsed JSONRPCMessage, which could be a request,
notification, response, or error. The 'raw' attribute preserves the original
message as received, allowing access to any additional context or metadata that
might be lost in parsing.
This dual representation allows for flexible handling of messages, where consumers
can work with the structured data for standard operations, but still have the
option to examine or utilize the raw data when needed, such as for debugging,
logging, or accessing transport-specific information.
"""
message: JSONRPCMessage
raw: RawT | None = None
model_config = ConfigDict(extra="allow")
def model_dump(self, *args, **kwargs):
"""
Dumps the model to a dictionary, delegating to the root JSONRPCMessage.
This method allows for consistent serialization of the parsed message.
"""
return self.message.model_dump(*args, **kwargs)
def model_dump_json(self, *args, **kwargs):
"""
Dumps the model to a JSON string, delegating to the root JSONRPCMessage.
This method provides a convenient way to serialize the parsed message to JSON.
"""
return self.message.model_dump_json(*args, **kwargs)
class EmptyResult(Result):
"""A response that indicates success but carries no data."""

View File

@@ -1,5 +1,3 @@
from types import NoneType
import anyio
import pytest
@@ -13,9 +11,9 @@ from mcp.types import (
InitializeRequest,
InitializeResult,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
MessageFrame,
ServerCapabilities,
ServerResult,
)
@@ -24,10 +22,10 @@ from mcp.types import (
@pytest.mark.anyio
async def test_client_session_initialize():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
MessageFrame[NoneType]
JSONRPCMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
MessageFrame[NoneType]
JSONRPCMessage
](1)
initialized_notification = None
@@ -36,7 +34,7 @@ async def test_client_session_initialize():
nonlocal initialized_notification
jsonrpc_request = await client_to_server_receive.receive()
assert isinstance(jsonrpc_request, MessageFrame)
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
@@ -58,25 +56,21 @@ async def test_client_session_initialize():
)
async with server_to_client_send:
assert isinstance(jsonrpc_request.message.root, JSONRPCRequest)
await server_to_client_send.send(
MessageFrame(
message=JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.message.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
),
raw=None,
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
)
)
jsonrpc_notification = await client_to_server_receive.receive()
assert isinstance(jsonrpc_notification.message, JSONRPCMessage)
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
initialized_notification = ClientNotification.model_validate(
jsonrpc_notification.message.model_dump(
jsonrpc_notification.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)

View File

@@ -11,7 +11,6 @@ from mcp.types import (
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
MessageFrame,
NotificationParams,
)
@@ -65,9 +64,7 @@ async def test_request_id_match() -> None:
jsonrpc="2.0",
)
await client_writer.send(
MessageFrame(message=JSONRPCMessage(root=init_req), raw=None)
)
await client_writer.send(JSONRPCMessage(root=init_req))
await server_reader.receive() # Get init response but don't need to check it
# Send initialized notification
@@ -76,27 +73,21 @@ async def test_request_id_match() -> None:
params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
jsonrpc="2.0",
)
await client_writer.send(
MessageFrame(
message=JSONRPCMessage(root=initialized_notification), raw=None
)
)
await client_writer.send(JSONRPCMessage(root=initialized_notification))
# Send ping request with custom ID
ping_request = JSONRPCRequest(
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
)
await client_writer.send(
MessageFrame(message=JSONRPCMessage(root=ping_request), raw=None)
)
await client_writer.send(JSONRPCMessage(root=ping_request))
# Read response
response = await server_reader.receive()
# Verify response ID matches request ID
assert (
response.message.root.id == custom_request_id
response.root.id == custom_request_id
), "Response ID should match request ID"
# Cancel server task

View File

@@ -17,7 +17,6 @@ from mcp.types import (
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
MessageFrame,
)
@@ -65,7 +64,7 @@ async def test_lowlevel_server_lifespan():
send_stream2,
InitializationOptions(
server_name="test",
server_version="1.0.0",
server_version="0.1.0",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
@@ -83,51 +82,42 @@ async def test_lowlevel_server_lifespan():
clientInfo=Implementation(name="test-client", version="0.1.0"),
)
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
)
)
response = await receive_stream2.receive()
# Send initialized notification
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
)
)
# Call the tool to verify lifespan context
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
)
)
# Get response and verify
response = await receive_stream2.receive()
assert response.message.root.result["content"][0]["text"] == "true"
assert response.root.result["content"][0]["text"] == "true"
# Cancel server task
tg.cancel_scope.cancel()
@@ -188,51 +178,42 @@ async def test_fastmcp_server_lifespan():
clientInfo=Implementation(name="test-client", version="0.1.0"),
)
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
)
)
response = await receive_stream2.receive()
# Send initialized notification
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
)
)
# Call the tool to verify lifespan context
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
)
)
# Get response and verify
response = await receive_stream2.receive()
assert response.message.root.result["content"][0]["text"] == "true"
assert response.root.result["content"][0]["text"] == "true"
# Cancel server task
tg.cancel_scope.cancel()

View File

@@ -9,7 +9,7 @@ from mcp.server.session import ServerSession
from mcp.types import (
ClientNotification,
InitializedNotification,
MessageFrame,
JSONRPCMessage,
PromptsCapability,
ResourcesCapability,
ServerCapabilities,
@@ -19,10 +19,10 @@ from mcp.types import (
@pytest.mark.anyio
async def test_server_session_initialize():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
MessageFrame[None]
JSONRPCMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
MessageFrame[None]
JSONRPCMessage
](1)
async def run_client(client: ClientSession):

View File

@@ -4,7 +4,7 @@ import anyio
import pytest
from mcp.server.stdio import stdio_server
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, MessageFrame
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
@pytest.mark.anyio
@@ -13,8 +13,8 @@ async def test_stdio_server():
stdout = io.StringIO()
messages = [
JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"),
JSONRPCResponse(jsonrpc="2.0", id=2, result={}),
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")),
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})),
]
for message in messages:
@@ -35,29 +35,17 @@ async def test_stdio_server():
# Verify received messages
assert len(received_messages) == 2
assert isinstance(received_messages[0].message, JSONRPCMessage)
assert isinstance(received_messages[0].message.root, JSONRPCRequest)
assert received_messages[0].message.root.id == 1
assert received_messages[0].message.root.method == "ping"
assert isinstance(received_messages[1].message, JSONRPCMessage)
assert isinstance(received_messages[1].message.root, JSONRPCResponse)
assert received_messages[1].message.root.id == 2
assert received_messages[0] == JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
)
assert received_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)
# Test sending responses from the server
responses = [
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
),
raw=None,
),
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
),
raw=None,
),
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")),
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})),
]
async with write_stream:
@@ -68,10 +56,13 @@ async def test_stdio_server():
output_lines = stdout.readlines()
assert len(output_lines) == 2
# Parse and verify the JSON responses directly
request_json = JSONRPCRequest.model_validate_json(output_lines[0].strip())
response_json = JSONRPCResponse.model_validate_json(output_lines[1].strip())
assert request_json.id == 3
assert request_json.method == "ping"
assert response_json.id == 4
received_responses = [
JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines
]
assert len(received_responses) == 2
assert received_responses[0] == JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
)
assert received_responses[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
)