Revert "refactor: reorganize message handling for better type safety and clar…" (#282)

This reverts commit 9d0f2daddb.
This commit is contained in:
Marcelo Trylesinski
2025-03-14 10:50:46 +01:00
committed by GitHub
parent ebb81d3b2b
commit 7196604468
17 changed files with 151 additions and 283 deletions

View File

@@ -1,5 +1,3 @@
from types import NoneType
import anyio
import pytest
@@ -13,9 +11,9 @@ from mcp.types import (
InitializeRequest,
InitializeResult,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
MessageFrame,
ServerCapabilities,
ServerResult,
)
@@ -24,10 +22,10 @@ from mcp.types import (
@pytest.mark.anyio
async def test_client_session_initialize():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
MessageFrame[NoneType]
JSONRPCMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
MessageFrame[NoneType]
JSONRPCMessage
](1)
initialized_notification = None
@@ -36,7 +34,7 @@ async def test_client_session_initialize():
nonlocal initialized_notification
jsonrpc_request = await client_to_server_receive.receive()
assert isinstance(jsonrpc_request, MessageFrame)
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
@@ -58,25 +56,21 @@ async def test_client_session_initialize():
)
async with server_to_client_send:
assert isinstance(jsonrpc_request.message.root, JSONRPCRequest)
await server_to_client_send.send(
MessageFrame(
message=JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.message.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
),
raw=None,
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
)
)
jsonrpc_notification = await client_to_server_receive.receive()
assert isinstance(jsonrpc_notification.message, JSONRPCMessage)
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
initialized_notification = ClientNotification.model_validate(
jsonrpc_notification.message.model_dump(
jsonrpc_notification.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)

View File

@@ -11,7 +11,6 @@ from mcp.types import (
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
MessageFrame,
NotificationParams,
)
@@ -65,9 +64,7 @@ async def test_request_id_match() -> None:
jsonrpc="2.0",
)
await client_writer.send(
MessageFrame(message=JSONRPCMessage(root=init_req), raw=None)
)
await client_writer.send(JSONRPCMessage(root=init_req))
await server_reader.receive() # Get init response but don't need to check it
# Send initialized notification
@@ -76,27 +73,21 @@ async def test_request_id_match() -> None:
params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
jsonrpc="2.0",
)
await client_writer.send(
MessageFrame(
message=JSONRPCMessage(root=initialized_notification), raw=None
)
)
await client_writer.send(JSONRPCMessage(root=initialized_notification))
# Send ping request with custom ID
ping_request = JSONRPCRequest(
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
)
await client_writer.send(
MessageFrame(message=JSONRPCMessage(root=ping_request), raw=None)
)
await client_writer.send(JSONRPCMessage(root=ping_request))
# Read response
response = await server_reader.receive()
# Verify response ID matches request ID
assert (
response.message.root.id == custom_request_id
response.root.id == custom_request_id
), "Response ID should match request ID"
# Cancel server task

View File

@@ -17,7 +17,6 @@ from mcp.types import (
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
MessageFrame,
)
@@ -65,7 +64,7 @@ async def test_lowlevel_server_lifespan():
send_stream2,
InitializationOptions(
server_name="test",
server_version="1.0.0",
server_version="0.1.0",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
@@ -83,51 +82,42 @@ async def test_lowlevel_server_lifespan():
clientInfo=Implementation(name="test-client", version="0.1.0"),
)
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
)
)
response = await receive_stream2.receive()
# Send initialized notification
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
)
)
# Call the tool to verify lifespan context
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
)
)
# Get response and verify
response = await receive_stream2.receive()
assert response.message.root.result["content"][0]["text"] == "true"
assert response.root.result["content"][0]["text"] == "true"
# Cancel server task
tg.cancel_scope.cancel()
@@ -188,51 +178,42 @@ async def test_fastmcp_server_lifespan():
clientInfo=Implementation(name="test-client", version="0.1.0"),
)
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
)
)
response = await receive_stream2.receive()
# Send initialized notification
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
)
)
# Call the tool to verify lifespan context
await send_stream1.send(
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
),
raw=None,
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
)
)
# Get response and verify
response = await receive_stream2.receive()
assert response.message.root.result["content"][0]["text"] == "true"
assert response.root.result["content"][0]["text"] == "true"
# Cancel server task
tg.cancel_scope.cancel()

View File

@@ -9,7 +9,7 @@ from mcp.server.session import ServerSession
from mcp.types import (
ClientNotification,
InitializedNotification,
MessageFrame,
JSONRPCMessage,
PromptsCapability,
ResourcesCapability,
ServerCapabilities,
@@ -19,10 +19,10 @@ from mcp.types import (
@pytest.mark.anyio
async def test_server_session_initialize():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
MessageFrame[None]
JSONRPCMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
MessageFrame[None]
JSONRPCMessage
](1)
async def run_client(client: ClientSession):

View File

@@ -4,7 +4,7 @@ import anyio
import pytest
from mcp.server.stdio import stdio_server
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, MessageFrame
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
@pytest.mark.anyio
@@ -13,8 +13,8 @@ async def test_stdio_server():
stdout = io.StringIO()
messages = [
JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"),
JSONRPCResponse(jsonrpc="2.0", id=2, result={}),
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")),
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})),
]
for message in messages:
@@ -35,29 +35,17 @@ async def test_stdio_server():
# Verify received messages
assert len(received_messages) == 2
assert isinstance(received_messages[0].message, JSONRPCMessage)
assert isinstance(received_messages[0].message.root, JSONRPCRequest)
assert received_messages[0].message.root.id == 1
assert received_messages[0].message.root.method == "ping"
assert isinstance(received_messages[1].message, JSONRPCMessage)
assert isinstance(received_messages[1].message.root, JSONRPCResponse)
assert received_messages[1].message.root.id == 2
assert received_messages[0] == JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
)
assert received_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)
# Test sending responses from the server
responses = [
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
),
raw=None,
),
MessageFrame(
message=JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
),
raw=None,
),
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")),
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})),
]
async with write_stream:
@@ -68,10 +56,13 @@ async def test_stdio_server():
output_lines = stdout.readlines()
assert len(output_lines) == 2
# Parse and verify the JSON responses directly
request_json = JSONRPCRequest.model_validate_json(output_lines[0].strip())
response_json = JSONRPCResponse.model_validate_json(output_lines[1].strip())
assert request_json.id == 3
assert request_json.method == "ping"
assert response_json.id == 4
received_responses = [
JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines
]
assert len(received_responses) == 2
assert received_responses[0] == JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
)
assert received_responses[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
)