mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
refactor: reorganize message handling for better type safety and clarity (#239)
* refactor: improve typing with memory stream type aliases Move memory stream type definitions to models.py and use them throughout the codebase for better type safety and maintainability. GitHub-Issue:#201 * refactor: move streams to ParsedMessage * refactor: update test files to use ParsedMessage Updates test files to work with the ParsedMessage stream type aliases and fixes a line length issue in test_201_client_hangs_on_logging.py. Github-Issue:#201 * refactor: rename ParsedMessage to MessageFrame for clarity 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * refactor: move MessageFrame class to types.py for better code organization 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * fix pyright * refactor: update websocket client to use MessageFrame Modified the websocket client to work with the new MessageFrame type, preserving raw message text and properly extracting the root JSON-RPC message when sending. Github-Issue:#204 * fix: use NoneType instead of None for type parameters in MessageFrame 🤖 Generated with [Claude Code](https://claude.ai/code) * refactor: rename root to message
This commit is contained in:
committed by
GitHub
parent
ad7f7a5473
commit
9d0f2daddb
@@ -1,12 +1,11 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
||||||
from pydantic import AnyUrl, TypeAdapter
|
from pydantic import AnyUrl, TypeAdapter
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.shared.session import BaseSession, RequestResponder
|
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
|
||||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
|
||||||
|
|
||||||
@@ -59,8 +58,8 @@ class ClientSession(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
read_stream: ReadStream,
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
write_stream: WriteStream,
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
sampling_callback: SamplingFnT | None = None,
|
sampling_callback: SamplingFnT | None = None,
|
||||||
list_roots_callback: ListRootsFnT | None = None,
|
list_roots_callback: ListRootsFnT | None = None,
|
||||||
|
|||||||
@@ -6,10 +6,16 @@ from urllib.parse import urljoin, urlparse
|
|||||||
import anyio
|
import anyio
|
||||||
import httpx
|
import httpx
|
||||||
from anyio.abc import TaskStatus
|
from anyio.abc import TaskStatus
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
||||||
from httpx_sse import aconnect_sse
|
from httpx_sse import aconnect_sse
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.session import (
|
||||||
|
ReadStream,
|
||||||
|
ReadStreamWriter,
|
||||||
|
WriteStream,
|
||||||
|
WriteStreamReader,
|
||||||
|
)
|
||||||
|
from mcp.types import MessageFrame
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -31,11 +37,11 @@ async def sse_client(
|
|||||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||||
"""
|
"""
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: ReadStream
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: ReadStreamWriter
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: WriteStream
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: WriteStreamReader
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -84,8 +90,11 @@ async def sse_client(
|
|||||||
|
|
||||||
case "message":
|
case "message":
|
||||||
try:
|
try:
|
||||||
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
|
message = MessageFrame(
|
||||||
sse.data
|
message=types.JSONRPCMessage.model_validate_json( # noqa: E501
|
||||||
|
sse.data
|
||||||
|
),
|
||||||
|
raw=sse,
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Received server message: {message}"
|
f"Received server message: {message}"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
@@ -10,6 +10,7 @@ from websockets.asyncio.client import connect as ws_connect
|
|||||||
from websockets.typing import Subprotocol
|
from websockets.typing import Subprotocol
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.types import MessageFrame
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -19,8 +20,8 @@ async def websocket_client(
|
|||||||
url: str,
|
url: str,
|
||||||
) -> AsyncGenerator[
|
) -> AsyncGenerator[
|
||||||
tuple[
|
tuple[
|
||||||
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
MemoryObjectReceiveStream[MessageFrame[Any] | Exception],
|
||||||
MemoryObjectSendStream[types.JSONRPCMessage],
|
MemoryObjectSendStream[MessageFrame[Any]],
|
||||||
],
|
],
|
||||||
None,
|
None,
|
||||||
]:
|
]:
|
||||||
@@ -53,7 +54,11 @@ async def websocket_client(
|
|||||||
async with read_stream_writer:
|
async with read_stream_writer:
|
||||||
async for raw_text in ws:
|
async for raw_text in ws:
|
||||||
try:
|
try:
|
||||||
message = types.JSONRPCMessage.model_validate_json(raw_text)
|
json_message = types.JSONRPCMessage.model_validate_json(
|
||||||
|
raw_text
|
||||||
|
)
|
||||||
|
# Create MessageFrame with JSON message as root
|
||||||
|
message = MessageFrame(message=json_message, raw=raw_text)
|
||||||
await read_stream_writer.send(message)
|
await read_stream_writer.send(message)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
# If JSON parse or model validation fails, send the exception
|
# If JSON parse or model validation fails, send the exception
|
||||||
@@ -66,8 +71,8 @@ async def websocket_client(
|
|||||||
"""
|
"""
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for message in write_stream_reader:
|
async for message in write_stream_reader:
|
||||||
# Convert to a dict, then to JSON
|
# Extract the JSON-RPC message from MessageFrame and convert to JSON
|
||||||
msg_dict = message.model_dump(
|
msg_dict = message.message.model_dump(
|
||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
)
|
)
|
||||||
await ws.send(json.dumps(msg_dict))
|
await ws.send(json.dumps(msg_dict))
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontext
|
|||||||
from typing import Any, AsyncIterator, Generic, TypeVar
|
from typing import Any, AsyncIterator, Generic, TypeVar
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
@@ -84,7 +83,7 @@ from mcp.server.session import ServerSession
|
|||||||
from mcp.server.stdio import stdio_server as stdio_server
|
from mcp.server.stdio import stdio_server as stdio_server
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import ReadStream, RequestResponder, WriteStream
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -474,8 +473,8 @@ class Server(Generic[LifespanResultT]):
|
|||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
read_stream: ReadStream,
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
write_stream: WriteStream,
|
||||||
initialization_options: InitializationOptions,
|
initialization_options: InitializationOptions,
|
||||||
# When False, exceptions are returned as messages to the client.
|
# When False, exceptions are returned as messages to the client.
|
||||||
# When True, exceptions are raised, which will cause the server to shut down
|
# When True, exceptions are raised, which will cause the server to shut down
|
||||||
|
|||||||
@@ -5,9 +5,7 @@ and tools.
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from mcp.types import (
|
from mcp.types import ServerCapabilities
|
||||||
ServerCapabilities,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class InitializationOptions(BaseModel):
|
class InitializationOptions(BaseModel):
|
||||||
|
|||||||
@@ -42,14 +42,15 @@ from typing import Any, TypeVar
|
|||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import anyio.lowlevel
|
import anyio.lowlevel
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.server.models import InitializationOptions
|
from mcp.server.models import InitializationOptions
|
||||||
from mcp.shared.session import (
|
from mcp.shared.session import (
|
||||||
BaseSession,
|
BaseSession,
|
||||||
|
ReadStream,
|
||||||
RequestResponder,
|
RequestResponder,
|
||||||
|
WriteStream,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -76,8 +77,8 @@ class ServerSession(
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
read_stream: ReadStream,
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
write_stream: WriteStream,
|
||||||
init_options: InitializationOptions,
|
init_options: InitializationOptions,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ from urllib.parse import quote
|
|||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sse_starlette import EventSourceResponse
|
from sse_starlette import EventSourceResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
@@ -46,6 +45,13 @@ from starlette.responses import Response
|
|||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.session import (
|
||||||
|
ReadStream,
|
||||||
|
ReadStreamWriter,
|
||||||
|
WriteStream,
|
||||||
|
WriteStreamReader,
|
||||||
|
)
|
||||||
|
from mcp.types import MessageFrame
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -63,9 +69,7 @@ class SseServerTransport:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_endpoint: str
|
_endpoint: str
|
||||||
_read_stream_writers: dict[
|
_read_stream_writers: dict[UUID, ReadStreamWriter]
|
||||||
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, endpoint: str) -> None:
|
def __init__(self, endpoint: str) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -85,11 +89,11 @@ class SseServerTransport:
|
|||||||
raise ValueError("connect_sse can only handle HTTP requests")
|
raise ValueError("connect_sse can only handle HTTP requests")
|
||||||
|
|
||||||
logger.debug("Setting up SSE connection")
|
logger.debug("Setting up SSE connection")
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: ReadStream
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: ReadStreamWriter
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: WriteStream
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: WriteStreamReader
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -172,4 +176,4 @@ class SseServerTransport:
|
|||||||
logger.debug(f"Sending message to writer: {message}")
|
logger.debug(f"Sending message to writer: {message}")
|
||||||
response = Response("Accepted", status_code=202)
|
response = Response("Accepted", status_code=202)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
await writer.send(message)
|
await writer.send(MessageFrame(message=message, raw=request))
|
||||||
|
|||||||
@@ -24,9 +24,15 @@ from io import TextIOWrapper
|
|||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import anyio.lowlevel
|
import anyio.lowlevel
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.session import (
|
||||||
|
ReadStream,
|
||||||
|
ReadStreamWriter,
|
||||||
|
WriteStream,
|
||||||
|
WriteStreamReader,
|
||||||
|
)
|
||||||
|
from mcp.types import MessageFrame
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -47,11 +53,11 @@ async def stdio_server(
|
|||||||
if not stdout:
|
if not stdout:
|
||||||
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
|
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
|
||||||
|
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: ReadStream
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: ReadStreamWriter
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: WriteStream
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: WriteStreamReader
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -66,7 +72,9 @@ async def stdio_server(
|
|||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await read_stream_writer.send(message)
|
await read_stream_writer.send(
|
||||||
|
MessageFrame(message=message, raw=line)
|
||||||
|
)
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
await anyio.lowlevel.checkpoint()
|
await anyio.lowlevel.checkpoint()
|
||||||
|
|
||||||
@@ -74,6 +82,7 @@ async def stdio_server(
|
|||||||
try:
|
try:
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for message in write_stream_reader:
|
async for message in write_stream_reader:
|
||||||
|
# Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame
|
||||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||||
await stdout.write(json + "\n")
|
await stdout.write(json + "\n")
|
||||||
await stdout.flush()
|
await stdout.flush()
|
||||||
|
|||||||
@@ -2,11 +2,17 @@ import logging
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.session import (
|
||||||
|
ReadStream,
|
||||||
|
ReadStreamWriter,
|
||||||
|
WriteStream,
|
||||||
|
WriteStreamReader,
|
||||||
|
)
|
||||||
|
from mcp.types import MessageFrame
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -21,11 +27,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
|||||||
websocket = WebSocket(scope, receive, send)
|
websocket = WebSocket(scope, receive, send)
|
||||||
await websocket.accept(subprotocol="mcp")
|
await websocket.accept(subprotocol="mcp")
|
||||||
|
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: ReadStream
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: ReadStreamWriter
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: WriteStream
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: WriteStreamReader
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -40,7 +46,9 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
|||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await read_stream_writer.send(client_message)
|
await read_stream_writer.send(
|
||||||
|
MessageFrame(message=client_message, raw=message)
|
||||||
|
)
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
|
|||||||
|
|
||||||
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
|
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.types import JSONRPCMessage
|
from mcp.types import MessageFrame
|
||||||
|
|
||||||
MessageStream = tuple[
|
MessageStream = tuple[
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
MemoryObjectReceiveStream[MessageFrame | Exception],
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
MemoryObjectSendStream[MessageFrame],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> (
|
|||||||
"""
|
"""
|
||||||
# Create streams for both directions
|
# Create streams for both directions
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage | Exception
|
MessageFrame | Exception
|
||||||
](1)
|
](1)
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage | Exception
|
MessageFrame | Exception
|
||||||
](1)
|
](1)
|
||||||
|
|
||||||
client_streams = (server_to_client_receive, client_to_server_send)
|
client_streams = (server_to_client_receive, client_to_server_send)
|
||||||
@@ -60,12 +60,9 @@ async def create_connected_server_and_client_session(
|
|||||||
) -> AsyncGenerator[ClientSession, None]:
|
) -> AsyncGenerator[ClientSession, None]:
|
||||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||||
async with create_client_server_memory_streams() as (
|
async with create_client_server_memory_streams() as (
|
||||||
client_streams,
|
(client_read, client_write),
|
||||||
server_streams,
|
(server_read, server_write),
|
||||||
):
|
):
|
||||||
client_read, client_write = client_streams
|
|
||||||
server_read, server_write = server_streams
|
|
||||||
|
|
||||||
# Create a cancel scope for the server task
|
# Create a cancel scope for the server task
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
tg.start_soon(
|
tg.start_soon(
|
||||||
|
|||||||
@@ -22,12 +22,18 @@ from mcp.types import (
|
|||||||
JSONRPCNotification,
|
JSONRPCNotification,
|
||||||
JSONRPCRequest,
|
JSONRPCRequest,
|
||||||
JSONRPCResponse,
|
JSONRPCResponse,
|
||||||
|
MessageFrame,
|
||||||
RequestParams,
|
RequestParams,
|
||||||
ServerNotification,
|
ServerNotification,
|
||||||
ServerRequest,
|
ServerRequest,
|
||||||
ServerResult,
|
ServerResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ReadStream = MemoryObjectReceiveStream[MessageFrame | Exception]
|
||||||
|
ReadStreamWriter = MemoryObjectSendStream[MessageFrame | Exception]
|
||||||
|
WriteStream = MemoryObjectSendStream[MessageFrame]
|
||||||
|
WriteStreamReader = MemoryObjectReceiveStream[MessageFrame]
|
||||||
|
|
||||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||||
@@ -165,8 +171,8 @@ class BaseSession(
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: ReadStream,
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: WriteStream,
|
||||||
receive_request_type: type[ReceiveRequestT],
|
receive_request_type: type[ReceiveRequestT],
|
||||||
receive_notification_type: type[ReceiveNotificationT],
|
receive_notification_type: type[ReceiveNotificationT],
|
||||||
# If none, reading will never time out
|
# If none, reading will never time out
|
||||||
@@ -242,7 +248,9 @@ class BaseSession(
|
|||||||
|
|
||||||
# TODO: Support progress callbacks
|
# TODO: Support progress callbacks
|
||||||
|
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
await self._write_stream.send(
|
||||||
|
MessageFrame(message=JSONRPCMessage(jsonrpc_request), raw=None)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with anyio.fail_after(
|
with anyio.fail_after(
|
||||||
@@ -278,14 +286,18 @@ class BaseSession(
|
|||||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
|
await self._write_stream.send(
|
||||||
|
MessageFrame(message=JSONRPCMessage(jsonrpc_notification), raw=None)
|
||||||
|
)
|
||||||
|
|
||||||
async def _send_response(
|
async def _send_response(
|
||||||
self, request_id: RequestId, response: SendResultT | ErrorData
|
self, request_id: RequestId, response: SendResultT | ErrorData
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(response, ErrorData):
|
if isinstance(response, ErrorData):
|
||||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
|
await self._write_stream.send(
|
||||||
|
MessageFrame(message=JSONRPCMessage(jsonrpc_error), raw=None)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
jsonrpc_response = JSONRPCResponse(
|
jsonrpc_response = JSONRPCResponse(
|
||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
@@ -294,7 +306,9 @@ class BaseSession(
|
|||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
|
await self._write_stream.send(
|
||||||
|
MessageFrame(message=JSONRPCMessage(jsonrpc_response), raw=None)
|
||||||
|
)
|
||||||
|
|
||||||
async def _receive_loop(self) -> None:
|
async def _receive_loop(self) -> None:
|
||||||
async with (
|
async with (
|
||||||
@@ -302,10 +316,13 @@ class BaseSession(
|
|||||||
self._write_stream,
|
self._write_stream,
|
||||||
self._incoming_message_stream_writer,
|
self._incoming_message_stream_writer,
|
||||||
):
|
):
|
||||||
async for message in self._read_stream:
|
async for raw_message in self._read_stream:
|
||||||
if isinstance(message, Exception):
|
if isinstance(raw_message, Exception):
|
||||||
await self._incoming_message_stream_writer.send(message)
|
await self._incoming_message_stream_writer.send(raw_message)
|
||||||
elif isinstance(message.root, JSONRPCRequest):
|
continue
|
||||||
|
|
||||||
|
message = raw_message.message
|
||||||
|
if isinstance(message.root, JSONRPCRequest):
|
||||||
validated_request = self._receive_request_type.model_validate(
|
validated_request = self._receive_request_type.model_validate(
|
||||||
message.root.model_dump(
|
message.root.model_dump(
|
||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
|
|||||||
@@ -180,6 +180,49 @@ class JSONRPCMessage(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
RawT = TypeVar("RawT")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFrame(BaseModel, Generic[RawT]):
|
||||||
|
"""
|
||||||
|
A wrapper around the general message received that contains both the parsed message
|
||||||
|
and the raw message.
|
||||||
|
|
||||||
|
This class serves as an encapsulation for JSON-RPC messages, providing access to
|
||||||
|
both the parsed structure (root) and the original raw data. This design is
|
||||||
|
particularly useful for Server-Sent Events (SSE) consumers who may need to access
|
||||||
|
additional metadata or headers associated with the message.
|
||||||
|
|
||||||
|
The 'root' attribute contains the parsed JSONRPCMessage, which could be a request,
|
||||||
|
notification, response, or error. The 'raw' attribute preserves the original
|
||||||
|
message as received, allowing access to any additional context or metadata that
|
||||||
|
might be lost in parsing.
|
||||||
|
|
||||||
|
This dual representation allows for flexible handling of messages, where consumers
|
||||||
|
can work with the structured data for standard operations, but still have the
|
||||||
|
option to examine or utilize the raw data when needed, such as for debugging,
|
||||||
|
logging, or accessing transport-specific information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
message: JSONRPCMessage
|
||||||
|
raw: RawT | None = None
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
def model_dump(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Dumps the model to a dictionary, delegating to the root JSONRPCMessage.
|
||||||
|
This method allows for consistent serialization of the parsed message.
|
||||||
|
"""
|
||||||
|
return self.message.model_dump(*args, **kwargs)
|
||||||
|
|
||||||
|
def model_dump_json(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Dumps the model to a JSON string, delegating to the root JSONRPCMessage.
|
||||||
|
This method provides a convenient way to serialize the parsed message to JSON.
|
||||||
|
"""
|
||||||
|
return self.message.model_dump_json(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class EmptyResult(Result):
|
class EmptyResult(Result):
|
||||||
"""A response that indicates success but carries no data."""
|
"""A response that indicates success but carries no data."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from types import NoneType
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -11,9 +13,9 @@ from mcp.types import (
|
|||||||
InitializeRequest,
|
InitializeRequest,
|
||||||
InitializeResult,
|
InitializeResult,
|
||||||
JSONRPCMessage,
|
JSONRPCMessage,
|
||||||
JSONRPCNotification,
|
|
||||||
JSONRPCRequest,
|
JSONRPCRequest,
|
||||||
JSONRPCResponse,
|
JSONRPCResponse,
|
||||||
|
MessageFrame,
|
||||||
ServerCapabilities,
|
ServerCapabilities,
|
||||||
ServerResult,
|
ServerResult,
|
||||||
)
|
)
|
||||||
@@ -22,10 +24,10 @@ from mcp.types import (
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_client_session_initialize():
|
async def test_client_session_initialize():
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
MessageFrame[NoneType]
|
||||||
](1)
|
](1)
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
MessageFrame[NoneType]
|
||||||
](1)
|
](1)
|
||||||
|
|
||||||
initialized_notification = None
|
initialized_notification = None
|
||||||
@@ -34,7 +36,7 @@ async def test_client_session_initialize():
|
|||||||
nonlocal initialized_notification
|
nonlocal initialized_notification
|
||||||
|
|
||||||
jsonrpc_request = await client_to_server_receive.receive()
|
jsonrpc_request = await client_to_server_receive.receive()
|
||||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
assert isinstance(jsonrpc_request, MessageFrame)
|
||||||
request = ClientRequest.model_validate(
|
request = ClientRequest.model_validate(
|
||||||
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
)
|
)
|
||||||
@@ -56,21 +58,25 @@ async def test_client_session_initialize():
|
|||||||
)
|
)
|
||||||
|
|
||||||
async with server_to_client_send:
|
async with server_to_client_send:
|
||||||
|
assert isinstance(jsonrpc_request.message.root, JSONRPCRequest)
|
||||||
await server_to_client_send.send(
|
await server_to_client_send.send(
|
||||||
JSONRPCMessage(
|
MessageFrame(
|
||||||
JSONRPCResponse(
|
message=JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
JSONRPCResponse(
|
||||||
id=jsonrpc_request.root.id,
|
jsonrpc="2.0",
|
||||||
result=result.model_dump(
|
id=jsonrpc_request.message.root.id,
|
||||||
by_alias=True, mode="json", exclude_none=True
|
result=result.model_dump(
|
||||||
),
|
by_alias=True, mode="json", exclude_none=True
|
||||||
)
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
jsonrpc_notification = await client_to_server_receive.receive()
|
jsonrpc_notification = await client_to_server_receive.receive()
|
||||||
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
assert isinstance(jsonrpc_notification.message, JSONRPCMessage)
|
||||||
initialized_notification = ClientNotification.model_validate(
|
initialized_notification = ClientNotification.model_validate(
|
||||||
jsonrpc_notification.model_dump(
|
jsonrpc_notification.message.model_dump(
|
||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from mcp.types import (
|
|||||||
JSONRPCMessage,
|
JSONRPCMessage,
|
||||||
JSONRPCNotification,
|
JSONRPCNotification,
|
||||||
JSONRPCRequest,
|
JSONRPCRequest,
|
||||||
|
MessageFrame,
|
||||||
NotificationParams,
|
NotificationParams,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -64,7 +65,9 @@ async def test_request_id_match() -> None:
|
|||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
await client_writer.send(JSONRPCMessage(root=init_req))
|
await client_writer.send(
|
||||||
|
MessageFrame(message=JSONRPCMessage(root=init_req), raw=None)
|
||||||
|
)
|
||||||
await server_reader.receive() # Get init response but don't need to check it
|
await server_reader.receive() # Get init response but don't need to check it
|
||||||
|
|
||||||
# Send initialized notification
|
# Send initialized notification
|
||||||
@@ -73,21 +76,27 @@ async def test_request_id_match() -> None:
|
|||||||
params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
|
params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
|
||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
)
|
)
|
||||||
await client_writer.send(JSONRPCMessage(root=initialized_notification))
|
await client_writer.send(
|
||||||
|
MessageFrame(
|
||||||
|
message=JSONRPCMessage(root=initialized_notification), raw=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Send ping request with custom ID
|
# Send ping request with custom ID
|
||||||
ping_request = JSONRPCRequest(
|
ping_request = JSONRPCRequest(
|
||||||
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
|
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
await client_writer.send(JSONRPCMessage(root=ping_request))
|
await client_writer.send(
|
||||||
|
MessageFrame(message=JSONRPCMessage(root=ping_request), raw=None)
|
||||||
|
)
|
||||||
|
|
||||||
# Read response
|
# Read response
|
||||||
response = await server_reader.receive()
|
response = await server_reader.receive()
|
||||||
|
|
||||||
# Verify response ID matches request ID
|
# Verify response ID matches request ID
|
||||||
assert (
|
assert (
|
||||||
response.root.id == custom_request_id
|
response.message.root.id == custom_request_id
|
||||||
), "Response ID should match request ID"
|
), "Response ID should match request ID"
|
||||||
|
|
||||||
# Cancel server task
|
# Cancel server task
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from mcp.types import (
|
|||||||
JSONRPCMessage,
|
JSONRPCMessage,
|
||||||
JSONRPCNotification,
|
JSONRPCNotification,
|
||||||
JSONRPCRequest,
|
JSONRPCRequest,
|
||||||
|
MessageFrame,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -64,7 +65,7 @@ async def test_lowlevel_server_lifespan():
|
|||||||
send_stream2,
|
send_stream2,
|
||||||
InitializationOptions(
|
InitializationOptions(
|
||||||
server_name="test",
|
server_name="test",
|
||||||
server_version="0.1.0",
|
server_version="1.0.0",
|
||||||
capabilities=server.get_capabilities(
|
capabilities=server.get_capabilities(
|
||||||
notification_options=NotificationOptions(),
|
notification_options=NotificationOptions(),
|
||||||
experimental_capabilities={},
|
experimental_capabilities={},
|
||||||
@@ -82,42 +83,51 @@ async def test_lowlevel_server_lifespan():
|
|||||||
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
||||||
)
|
)
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
MessageFrame(
|
||||||
root=JSONRPCRequest(
|
message=JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCRequest(
|
||||||
id=1,
|
jsonrpc="2.0",
|
||||||
method="initialize",
|
id=1,
|
||||||
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
method="initialize",
|
||||||
)
|
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await receive_stream2.receive()
|
response = await receive_stream2.receive()
|
||||||
|
|
||||||
# Send initialized notification
|
# Send initialized notification
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
MessageFrame(
|
||||||
root=JSONRPCNotification(
|
message=JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCNotification(
|
||||||
method="notifications/initialized",
|
jsonrpc="2.0",
|
||||||
)
|
method="notifications/initialized",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the tool to verify lifespan context
|
# Call the tool to verify lifespan context
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
MessageFrame(
|
||||||
root=JSONRPCRequest(
|
message=JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCRequest(
|
||||||
id=2,
|
jsonrpc="2.0",
|
||||||
method="tools/call",
|
id=2,
|
||||||
params={"name": "check_lifespan", "arguments": {}},
|
method="tools/call",
|
||||||
)
|
params={"name": "check_lifespan", "arguments": {}},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get response and verify
|
# Get response and verify
|
||||||
response = await receive_stream2.receive()
|
response = await receive_stream2.receive()
|
||||||
assert response.root.result["content"][0]["text"] == "true"
|
assert response.message.root.result["content"][0]["text"] == "true"
|
||||||
|
|
||||||
# Cancel server task
|
# Cancel server task
|
||||||
tg.cancel_scope.cancel()
|
tg.cancel_scope.cancel()
|
||||||
@@ -178,42 +188,51 @@ async def test_fastmcp_server_lifespan():
|
|||||||
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
||||||
)
|
)
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
MessageFrame(
|
||||||
root=JSONRPCRequest(
|
message=JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCRequest(
|
||||||
id=1,
|
jsonrpc="2.0",
|
||||||
method="initialize",
|
id=1,
|
||||||
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
method="initialize",
|
||||||
)
|
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await receive_stream2.receive()
|
response = await receive_stream2.receive()
|
||||||
|
|
||||||
# Send initialized notification
|
# Send initialized notification
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
MessageFrame(
|
||||||
root=JSONRPCNotification(
|
message=JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCNotification(
|
||||||
method="notifications/initialized",
|
jsonrpc="2.0",
|
||||||
)
|
method="notifications/initialized",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the tool to verify lifespan context
|
# Call the tool to verify lifespan context
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
MessageFrame(
|
||||||
root=JSONRPCRequest(
|
message=JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCRequest(
|
||||||
id=2,
|
jsonrpc="2.0",
|
||||||
method="tools/call",
|
id=2,
|
||||||
params={"name": "check_lifespan", "arguments": {}},
|
method="tools/call",
|
||||||
)
|
params={"name": "check_lifespan", "arguments": {}},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get response and verify
|
# Get response and verify
|
||||||
response = await receive_stream2.receive()
|
response = await receive_stream2.receive()
|
||||||
assert response.root.result["content"][0]["text"] == "true"
|
assert response.message.root.result["content"][0]["text"] == "true"
|
||||||
|
|
||||||
# Cancel server task
|
# Cancel server task
|
||||||
tg.cancel_scope.cancel()
|
tg.cancel_scope.cancel()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from mcp.server.session import ServerSession
|
|||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
ClientNotification,
|
ClientNotification,
|
||||||
InitializedNotification,
|
InitializedNotification,
|
||||||
JSONRPCMessage,
|
MessageFrame,
|
||||||
PromptsCapability,
|
PromptsCapability,
|
||||||
ResourcesCapability,
|
ResourcesCapability,
|
||||||
ServerCapabilities,
|
ServerCapabilities,
|
||||||
@@ -19,10 +19,10 @@ from mcp.types import (
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_server_session_initialize():
|
async def test_server_session_initialize():
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
MessageFrame[None]
|
||||||
](1)
|
](1)
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
MessageFrame[None]
|
||||||
](1)
|
](1)
|
||||||
|
|
||||||
async def run_client(client: ClientSession):
|
async def run_client(client: ClientSession):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import anyio
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mcp.server.stdio import stdio_server
|
from mcp.server.stdio import stdio_server
|
||||||
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, MessageFrame
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -13,8 +13,8 @@ async def test_stdio_server():
|
|||||||
stdout = io.StringIO()
|
stdout = io.StringIO()
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")),
|
JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"),
|
||||||
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})),
|
JSONRPCResponse(jsonrpc="2.0", id=2, result={}),
|
||||||
]
|
]
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@@ -35,17 +35,29 @@ async def test_stdio_server():
|
|||||||
|
|
||||||
# Verify received messages
|
# Verify received messages
|
||||||
assert len(received_messages) == 2
|
assert len(received_messages) == 2
|
||||||
assert received_messages[0] == JSONRPCMessage(
|
assert isinstance(received_messages[0].message, JSONRPCMessage)
|
||||||
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
|
assert isinstance(received_messages[0].message.root, JSONRPCRequest)
|
||||||
)
|
assert received_messages[0].message.root.id == 1
|
||||||
assert received_messages[1] == JSONRPCMessage(
|
assert received_messages[0].message.root.method == "ping"
|
||||||
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
|
|
||||||
)
|
assert isinstance(received_messages[1].message, JSONRPCMessage)
|
||||||
|
assert isinstance(received_messages[1].message.root, JSONRPCResponse)
|
||||||
|
assert received_messages[1].message.root.id == 2
|
||||||
|
|
||||||
# Test sending responses from the server
|
# Test sending responses from the server
|
||||||
responses = [
|
responses = [
|
||||||
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")),
|
MessageFrame(
|
||||||
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})),
|
message=JSONRPCMessage(
|
||||||
|
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
|
),
|
||||||
|
MessageFrame(
|
||||||
|
message=JSONRPCMessage(
|
||||||
|
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
|
||||||
|
),
|
||||||
|
raw=None,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
async with write_stream:
|
async with write_stream:
|
||||||
@@ -56,13 +68,10 @@ async def test_stdio_server():
|
|||||||
output_lines = stdout.readlines()
|
output_lines = stdout.readlines()
|
||||||
assert len(output_lines) == 2
|
assert len(output_lines) == 2
|
||||||
|
|
||||||
received_responses = [
|
# Parse and verify the JSON responses directly
|
||||||
JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines
|
request_json = JSONRPCRequest.model_validate_json(output_lines[0].strip())
|
||||||
]
|
response_json = JSONRPCResponse.model_validate_json(output_lines[1].strip())
|
||||||
assert len(received_responses) == 2
|
|
||||||
assert received_responses[0] == JSONRPCMessage(
|
assert request_json.id == 3
|
||||||
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
|
assert request_json.method == "ping"
|
||||||
)
|
assert response_json.id == 4
|
||||||
assert received_responses[1] == JSONRPCMessage(
|
|
||||||
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user