mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Wrap JSONRPC messages with SessionMessage for metadata support (#590)
This commit is contained in:
@@ -11,8 +11,8 @@ import mcp.types as types
|
|||||||
from mcp.client.session import ClientSession
|
from mcp.client.session import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
from mcp.types import JSONRPCMessage
|
|
||||||
|
|
||||||
if not sys.warnoptions:
|
if not sys.warnoptions:
|
||||||
import warnings
|
import warnings
|
||||||
@@ -36,8 +36,8 @@ async def message_handler(
|
|||||||
|
|
||||||
|
|
||||||
async def run_session(
|
async def run_session(
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||||
client_info: types.Implementation | None = None,
|
client_info: types.Implementation | None = None,
|
||||||
):
|
):
|
||||||
async with ClientSession(
|
async with ClientSession(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from pydantic import AnyUrl, TypeAdapter
|
|||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.shared.session import BaseSession, RequestResponder
|
from mcp.shared.session import BaseSession, RequestResponder
|
||||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
|
||||||
@@ -92,8 +93,8 @@ class ClientSession(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
sampling_callback: SamplingFnT | None = None,
|
sampling_callback: SamplingFnT | None = None,
|
||||||
list_roots_callback: ListRootsFnT | None = None,
|
list_roots_callback: ListRootsFnT | None = None,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
|
|||||||
from httpx_sse import aconnect_sse
|
from httpx_sse import aconnect_sse
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -31,11 +32,11 @@ async def sse_client(
|
|||||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||||
"""
|
"""
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -97,7 +98,8 @@ async def sse_client(
|
|||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await read_stream_writer.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await read_stream_writer.send(session_message)
|
||||||
case _:
|
case _:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Unknown SSE event: {sse.event}"
|
f"Unknown SSE event: {sse.event}"
|
||||||
@@ -111,11 +113,13 @@ async def sse_client(
|
|||||||
async def post_writer(endpoint_url: str):
|
async def post_writer(endpoint_url: str):
|
||||||
try:
|
try:
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
logger.debug(f"Sending client message: {message}")
|
logger.debug(
|
||||||
|
f"Sending client message: {session_message}"
|
||||||
|
)
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
endpoint_url,
|
endpoint_url,
|
||||||
json=message.model_dump(
|
json=session_message.message.model_dump(
|
||||||
by_alias=True,
|
by_alias=True,
|
||||||
mode="json",
|
mode="json",
|
||||||
exclude_none=True,
|
exclude_none=True,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from anyio.streams.text import TextReceiveStream
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
|
|
||||||
from .win32 import (
|
from .win32 import (
|
||||||
create_windows_process,
|
create_windows_process,
|
||||||
@@ -98,11 +99,11 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
|||||||
Client transport for stdio: this will connect to a server by spawning a
|
Client transport for stdio: this will connect to a server by spawning a
|
||||||
process and communicating with it over stdin/stdout.
|
process and communicating with it over stdin/stdout.
|
||||||
"""
|
"""
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -143,7 +144,8 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
|||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await read_stream_writer.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await read_stream_writer.send(session_message)
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
await anyio.lowlevel.checkpoint()
|
await anyio.lowlevel.checkpoint()
|
||||||
|
|
||||||
@@ -152,8 +154,10 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
json = session_message.message.model_dump_json(
|
||||||
|
by_alias=True, exclude_none=True
|
||||||
|
)
|
||||||
await process.stdin.send(
|
await process.stdin.send(
|
||||||
(json + "\n").encode(
|
(json + "\n").encode(
|
||||||
encoding=server.encoding,
|
encoding=server.encoding,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import anyio
|
|||||||
import httpx
|
import httpx
|
||||||
from httpx_sse import EventSource, aconnect_sse
|
from httpx_sse import EventSource, aconnect_sse
|
||||||
|
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
ErrorData,
|
ErrorData,
|
||||||
JSONRPCError,
|
JSONRPCError,
|
||||||
@@ -52,10 +53,10 @@ async def streamablehttp_client(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage | Exception
|
SessionMessage | Exception
|
||||||
](0)
|
](0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream[
|
write_stream, write_stream_reader = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](0)
|
](0)
|
||||||
|
|
||||||
async def get_stream():
|
async def get_stream():
|
||||||
@@ -86,7 +87,8 @@ async def streamablehttp_client(
|
|||||||
try:
|
try:
|
||||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||||
logger.debug(f"GET message: {message}")
|
logger.debug(f"GET message: {message}")
|
||||||
await read_stream_writer.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await read_stream_writer.send(session_message)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Error parsing GET message: {exc}")
|
logger.error(f"Error parsing GET message: {exc}")
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
@@ -100,7 +102,8 @@ async def streamablehttp_client(
|
|||||||
nonlocal session_id
|
nonlocal session_id
|
||||||
try:
|
try:
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
|
message = session_message.message
|
||||||
# Add session ID to headers if we have one
|
# Add session ID to headers if we have one
|
||||||
post_headers = request_headers.copy()
|
post_headers = request_headers.copy()
|
||||||
if session_id:
|
if session_id:
|
||||||
@@ -141,9 +144,10 @@ async def streamablehttp_client(
|
|||||||
message="Session terminated",
|
message="Session terminated",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
await read_stream_writer.send(
|
session_message = SessionMessage(
|
||||||
JSONRPCMessage(jsonrpc_error)
|
JSONRPCMessage(jsonrpc_error)
|
||||||
)
|
)
|
||||||
|
await read_stream_writer.send(session_message)
|
||||||
continue
|
continue
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -163,7 +167,8 @@ async def streamablehttp_client(
|
|||||||
json_message = JSONRPCMessage.model_validate_json(
|
json_message = JSONRPCMessage.model_validate_json(
|
||||||
content
|
content
|
||||||
)
|
)
|
||||||
await read_stream_writer.send(json_message)
|
session_message = SessionMessage(json_message)
|
||||||
|
await read_stream_writer.send(session_message)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Error parsing JSON response: {exc}")
|
logger.error(f"Error parsing JSON response: {exc}")
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
@@ -175,11 +180,15 @@ async def streamablehttp_client(
|
|||||||
async for sse in event_source.aiter_sse():
|
async for sse in event_source.aiter_sse():
|
||||||
if sse.event == "message":
|
if sse.event == "message":
|
||||||
try:
|
try:
|
||||||
await read_stream_writer.send(
|
message = (
|
||||||
JSONRPCMessage.model_validate_json(
|
JSONRPCMessage.model_validate_json(
|
||||||
sse.data
|
sse.data
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
session_message = SessionMessage(message)
|
||||||
|
await read_stream_writer.send(
|
||||||
|
session_message
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Error parsing message")
|
logger.exception("Error parsing message")
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from websockets.asyncio.client import connect as ws_connect
|
|||||||
from websockets.typing import Subprotocol
|
from websockets.typing import Subprotocol
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -19,8 +20,8 @@ async def websocket_client(
|
|||||||
url: str,
|
url: str,
|
||||||
) -> AsyncGenerator[
|
) -> AsyncGenerator[
|
||||||
tuple[
|
tuple[
|
||||||
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||||
MemoryObjectSendStream[types.JSONRPCMessage],
|
MemoryObjectSendStream[SessionMessage],
|
||||||
],
|
],
|
||||||
None,
|
None,
|
||||||
]:
|
]:
|
||||||
@@ -39,10 +40,10 @@ async def websocket_client(
|
|||||||
# Create two in-memory streams:
|
# Create two in-memory streams:
|
||||||
# - One for incoming messages (read_stream, written by ws_reader)
|
# - One for incoming messages (read_stream, written by ws_reader)
|
||||||
# - One for outgoing messages (write_stream, read by ws_writer)
|
# - One for outgoing messages (write_stream, read by ws_writer)
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -59,7 +60,8 @@ async def websocket_client(
|
|||||||
async for raw_text in ws:
|
async for raw_text in ws:
|
||||||
try:
|
try:
|
||||||
message = types.JSONRPCMessage.model_validate_json(raw_text)
|
message = types.JSONRPCMessage.model_validate_json(raw_text)
|
||||||
await read_stream_writer.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await read_stream_writer.send(session_message)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
# If JSON parse or model validation fails, send the exception
|
# If JSON parse or model validation fails, send the exception
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
@@ -70,9 +72,9 @@ async def websocket_client(
|
|||||||
sends them to the server.
|
sends them to the server.
|
||||||
"""
|
"""
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
# Convert to a dict, then to JSON
|
# Convert to a dict, then to JSON
|
||||||
msg_dict = message.model_dump(
|
msg_dict = session_message.message.model_dump(
|
||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
)
|
)
|
||||||
await ws.send(json.dumps(msg_dict))
|
await ws.send(json.dumps(msg_dict))
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ from mcp.server.session import ServerSession
|
|||||||
from mcp.server.stdio import stdio_server as stdio_server
|
from mcp.server.stdio import stdio_server as stdio_server
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -471,8 +472,8 @@ class Server(Generic[LifespanResultT]):
|
|||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||||
initialization_options: InitializationOptions,
|
initialization_options: InitializationOptions,
|
||||||
# When False, exceptions are returned as messages to the client.
|
# When False, exceptions are returned as messages to the client.
|
||||||
# When True, exceptions are raised, which will cause the server to shut down
|
# When True, exceptions are raised, which will cause the server to shut down
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ from pydantic import AnyUrl
|
|||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.server.models import InitializationOptions
|
from mcp.server.models import InitializationOptions
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.shared.session import (
|
from mcp.shared.session import (
|
||||||
BaseSession,
|
BaseSession,
|
||||||
RequestResponder,
|
RequestResponder,
|
||||||
@@ -82,8 +83,8 @@ class ServerSession(
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||||
init_options: InitializationOptions,
|
init_options: InitializationOptions,
|
||||||
stateless: bool = False,
|
stateless: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from starlette.responses import Response
|
|||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -63,9 +64,7 @@ class SseServerTransport:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_endpoint: str
|
_endpoint: str
|
||||||
_read_stream_writers: dict[
|
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
|
||||||
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, endpoint: str) -> None:
|
def __init__(self, endpoint: str) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -85,11 +84,11 @@ class SseServerTransport:
|
|||||||
raise ValueError("connect_sse can only handle HTTP requests")
|
raise ValueError("connect_sse can only handle HTTP requests")
|
||||||
|
|
||||||
logger.debug("Setting up SSE connection")
|
logger.debug("Setting up SSE connection")
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -109,12 +108,12 @@ class SseServerTransport:
|
|||||||
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
|
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
|
||||||
logger.debug(f"Sent endpoint event: {session_uri}")
|
logger.debug(f"Sent endpoint event: {session_uri}")
|
||||||
|
|
||||||
async for message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
logger.debug(f"Sending message via SSE: {message}")
|
logger.debug(f"Sending message via SSE: {session_message}")
|
||||||
await sse_stream_writer.send(
|
await sse_stream_writer.send(
|
||||||
{
|
{
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"data": message.model_dump_json(
|
"data": session_message.message.model_dump_json(
|
||||||
by_alias=True, exclude_none=True
|
by_alias=True, exclude_none=True
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@@ -169,7 +168,8 @@ class SseServerTransport:
|
|||||||
await writer.send(err)
|
await writer.send(err)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug(f"Sending message to writer: {message}")
|
session_message = SessionMessage(message)
|
||||||
|
logger.debug(f"Sending session message to writer: {session_message}")
|
||||||
response = Response("Accepted", status_code=202)
|
response = Response("Accepted", status_code=202)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
await writer.send(message)
|
await writer.send(session_message)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import anyio.lowlevel
|
|||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -47,11 +48,11 @@ async def stdio_server(
|
|||||||
if not stdout:
|
if not stdout:
|
||||||
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
|
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
|
||||||
|
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -66,15 +67,18 @@ async def stdio_server(
|
|||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await read_stream_writer.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await read_stream_writer.send(session_message)
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
await anyio.lowlevel.checkpoint()
|
await anyio.lowlevel.checkpoint()
|
||||||
|
|
||||||
async def stdout_writer():
|
async def stdout_writer():
|
||||||
try:
|
try:
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
json = session_message.message.model_dump_json(
|
||||||
|
by_alias=True, exclude_none=True
|
||||||
|
)
|
||||||
await stdout.write(json + "\n")
|
await stdout.write(json + "\n")
|
||||||
await stdout.flush()
|
await stdout.flush()
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from starlette.requests import Request
|
|||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
INTERNAL_ERROR,
|
INTERNAL_ERROR,
|
||||||
INVALID_PARAMS,
|
INVALID_PARAMS,
|
||||||
@@ -125,10 +126,10 @@ class StreamableHTTPServerTransport:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Server notification streams for POST requests as well as standalone SSE stream
|
# Server notification streams for POST requests as well as standalone SSE stream
|
||||||
_read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = (
|
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
_write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None
|
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -378,7 +379,8 @@ class StreamableHTTPServerTransport:
|
|||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
|
|
||||||
# Process the message after sending the response
|
# Process the message after sending the response
|
||||||
await writer.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await writer.send(session_message)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -394,7 +396,8 @@ class StreamableHTTPServerTransport:
|
|||||||
|
|
||||||
if self.is_json_response_enabled:
|
if self.is_json_response_enabled:
|
||||||
# Process the message
|
# Process the message
|
||||||
await writer.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await writer.send(session_message)
|
||||||
try:
|
try:
|
||||||
# Process messages from the request-specific stream
|
# Process messages from the request-specific stream
|
||||||
# We need to collect all messages until we get a response
|
# We need to collect all messages until we get a response
|
||||||
@@ -500,7 +503,8 @@ class StreamableHTTPServerTransport:
|
|||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
tg.start_soon(response, scope, receive, send)
|
tg.start_soon(response, scope, receive, send)
|
||||||
# Then send the message to be processed by the server
|
# Then send the message to be processed by the server
|
||||||
await writer.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await writer.send(session_message)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("SSE response error")
|
logger.exception("SSE response error")
|
||||||
# Clean up the request stream if something goes wrong
|
# Clean up the request stream if something goes wrong
|
||||||
@@ -792,8 +796,8 @@ class StreamableHTTPServerTransport:
|
|||||||
self,
|
self,
|
||||||
) -> AsyncGenerator[
|
) -> AsyncGenerator[
|
||||||
tuple[
|
tuple[
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
MemoryObjectSendStream[SessionMessage],
|
||||||
],
|
],
|
||||||
None,
|
None,
|
||||||
]:
|
]:
|
||||||
@@ -806,10 +810,10 @@ class StreamableHTTPServerTransport:
|
|||||||
# Create the memory streams for this connection
|
# Create the memory streams for this connection
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage | Exception
|
SessionMessage | Exception
|
||||||
](0)
|
](0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream[
|
write_stream, write_stream_reader = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](0)
|
](0)
|
||||||
|
|
||||||
# Store the streams
|
# Store the streams
|
||||||
@@ -821,8 +825,9 @@ class StreamableHTTPServerTransport:
|
|||||||
# Create a message router that distributes messages to request streams
|
# Create a message router that distributes messages to request streams
|
||||||
async def message_router():
|
async def message_router():
|
||||||
try:
|
try:
|
||||||
async for message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
# Determine which request stream(s) should receive this message
|
# Determine which request stream(s) should receive this message
|
||||||
|
message = session_message.message
|
||||||
target_request_id = None
|
target_request_id = None
|
||||||
if isinstance(
|
if isinstance(
|
||||||
message.root, JSONRPCNotification | JSONRPCRequest
|
message.root, JSONRPCNotification | JSONRPCRequest
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from starlette.types import Receive, Scope, Send
|
|||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -22,11 +23,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
|||||||
websocket = WebSocket(scope, receive, send)
|
websocket = WebSocket(scope, receive, send)
|
||||||
await websocket.accept(subprotocol="mcp")
|
await websocket.accept(subprotocol="mcp")
|
||||||
|
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -41,15 +42,18 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
|||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await read_stream_writer.send(client_message)
|
session_message = SessionMessage(client_message)
|
||||||
|
await read_stream_writer.send(session_message)
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
async def ws_writer():
|
async def ws_writer():
|
||||||
try:
|
try:
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
obj = message.model_dump_json(by_alias=True, exclude_none=True)
|
obj = session_message.message.model_dump_json(
|
||||||
|
by_alias=True, exclude_none=True
|
||||||
|
)
|
||||||
await websocket.send_text(obj)
|
await websocket.send_text(obj)
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ from mcp.client.session import (
|
|||||||
SamplingFnT,
|
SamplingFnT,
|
||||||
)
|
)
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.types import JSONRPCMessage
|
from mcp.shared.message import SessionMessage
|
||||||
|
|
||||||
MessageStream = tuple[
|
MessageStream = tuple[
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
MemoryObjectSendStream[SessionMessage],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -40,10 +40,10 @@ async def create_client_server_memory_streams() -> (
|
|||||||
"""
|
"""
|
||||||
# Create streams for both directions
|
# Create streams for both directions
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage | Exception
|
SessionMessage | Exception
|
||||||
](1)
|
](1)
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage | Exception
|
SessionMessage | Exception
|
||||||
](1)
|
](1)
|
||||||
|
|
||||||
client_streams = (server_to_client_receive, client_to_server_send)
|
client_streams = (server_to_client_receive, client_to_server_send)
|
||||||
|
|||||||
35
src/mcp/shared/message.py
Normal file
35
src/mcp/shared/message.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Message wrapper with metadata support.
|
||||||
|
|
||||||
|
This module defines a wrapper type that combines JSONRPCMessage with metadata
|
||||||
|
to support transport-specific features like resumability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from mcp.types import JSONRPCMessage, RequestId
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClientMessageMetadata:
|
||||||
|
"""Metadata specific to client messages."""
|
||||||
|
|
||||||
|
resumption_token: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ServerMessageMetadata:
|
||||||
|
"""Metadata specific to server messages."""
|
||||||
|
|
||||||
|
related_request_id: RequestId | None = None
|
||||||
|
|
||||||
|
|
||||||
|
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionMessage:
|
||||||
|
"""A message with specific metadata for transport-specific features."""
|
||||||
|
|
||||||
|
message: JSONRPCMessage
|
||||||
|
metadata: MessageMetadata = None
|
||||||
@@ -12,6 +12,7 @@ from pydantic import BaseModel
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
CancelledNotification,
|
CancelledNotification,
|
||||||
ClientNotification,
|
ClientNotification,
|
||||||
@@ -172,8 +173,8 @@ class BaseSession(
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||||
receive_request_type: type[ReceiveRequestT],
|
receive_request_type: type[ReceiveRequestT],
|
||||||
receive_notification_type: type[ReceiveNotificationT],
|
receive_notification_type: type[ReceiveNotificationT],
|
||||||
# If none, reading will never time out
|
# If none, reading will never time out
|
||||||
@@ -240,7 +241,9 @@ class BaseSession(
|
|||||||
|
|
||||||
# TODO: Support progress callbacks
|
# TODO: Support progress callbacks
|
||||||
|
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
await self._write_stream.send(
|
||||||
|
SessionMessage(message=JSONRPCMessage(jsonrpc_request))
|
||||||
|
)
|
||||||
|
|
||||||
# request read timeout takes precedence over session read timeout
|
# request read timeout takes precedence over session read timeout
|
||||||
timeout = None
|
timeout = None
|
||||||
@@ -300,14 +303,16 @@ class BaseSession(
|
|||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||||
)
|
)
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
|
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification))
|
||||||
|
await self._write_stream.send(session_message)
|
||||||
|
|
||||||
async def _send_response(
|
async def _send_response(
|
||||||
self, request_id: RequestId, response: SendResultT | ErrorData
|
self, request_id: RequestId, response: SendResultT | ErrorData
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(response, ErrorData):
|
if isinstance(response, ErrorData):
|
||||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
|
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||||
|
await self._write_stream.send(session_message)
|
||||||
else:
|
else:
|
||||||
jsonrpc_response = JSONRPCResponse(
|
jsonrpc_response = JSONRPCResponse(
|
||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
@@ -316,7 +321,8 @@ class BaseSession(
|
|||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
|
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||||
|
await self._write_stream.send(session_message)
|
||||||
|
|
||||||
async def _receive_loop(self) -> None:
|
async def _receive_loop(self) -> None:
|
||||||
async with (
|
async with (
|
||||||
@@ -326,15 +332,15 @@ class BaseSession(
|
|||||||
async for message in self._read_stream:
|
async for message in self._read_stream:
|
||||||
if isinstance(message, Exception):
|
if isinstance(message, Exception):
|
||||||
await self._handle_incoming(message)
|
await self._handle_incoming(message)
|
||||||
elif isinstance(message.root, JSONRPCRequest):
|
elif isinstance(message.message.root, JSONRPCRequest):
|
||||||
validated_request = self._receive_request_type.model_validate(
|
validated_request = self._receive_request_type.model_validate(
|
||||||
message.root.model_dump(
|
message.message.root.model_dump(
|
||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
responder = RequestResponder(
|
responder = RequestResponder(
|
||||||
request_id=message.root.id,
|
request_id=message.message.root.id,
|
||||||
request_meta=validated_request.root.params.meta
|
request_meta=validated_request.root.params.meta
|
||||||
if validated_request.root.params
|
if validated_request.root.params
|
||||||
else None,
|
else None,
|
||||||
@@ -349,10 +355,10 @@ class BaseSession(
|
|||||||
if not responder._completed: # type: ignore[reportPrivateUsage]
|
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||||
await self._handle_incoming(responder)
|
await self._handle_incoming(responder)
|
||||||
|
|
||||||
elif isinstance(message.root, JSONRPCNotification):
|
elif isinstance(message.message.root, JSONRPCNotification):
|
||||||
try:
|
try:
|
||||||
notification = self._receive_notification_type.model_validate(
|
notification = self._receive_notification_type.model_validate(
|
||||||
message.root.model_dump(
|
message.message.root.model_dump(
|
||||||
by_alias=True, mode="json", exclude_none=True
|
by_alias=True, mode="json", exclude_none=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -368,12 +374,12 @@ class BaseSession(
|
|||||||
# For other validation errors, log and continue
|
# For other validation errors, log and continue
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Failed to validate notification: {e}. "
|
f"Failed to validate notification: {e}. "
|
||||||
f"Message was: {message.root}"
|
f"Message was: {message.message.root}"
|
||||||
)
|
)
|
||||||
else: # Response or error
|
else: # Response or error
|
||||||
stream = self._response_streams.pop(message.root.id, None)
|
stream = self._response_streams.pop(message.message.root.id, None)
|
||||||
if stream:
|
if stream:
|
||||||
await stream.send(message.root)
|
await stream.send(message.message.root)
|
||||||
else:
|
else:
|
||||||
await self._handle_incoming(
|
await self._handle_incoming(
|
||||||
RuntimeError(
|
RuntimeError(
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import pytest
|
|||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
|
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
LATEST_PROTOCOL_VERSION,
|
LATEST_PROTOCOL_VERSION,
|
||||||
@@ -24,10 +25,10 @@ from mcp.types import (
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_client_session_initialize():
|
async def test_client_session_initialize():
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](1)
|
](1)
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](1)
|
](1)
|
||||||
|
|
||||||
initialized_notification = None
|
initialized_notification = None
|
||||||
@@ -35,7 +36,8 @@ async def test_client_session_initialize():
|
|||||||
async def mock_server():
|
async def mock_server():
|
||||||
nonlocal initialized_notification
|
nonlocal initialized_notification
|
||||||
|
|
||||||
jsonrpc_request = await client_to_server_receive.receive()
|
session_message = await client_to_server_receive.receive()
|
||||||
|
jsonrpc_request = session_message.message
|
||||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||||
request = ClientRequest.model_validate(
|
request = ClientRequest.model_validate(
|
||||||
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
@@ -59,17 +61,20 @@ async def test_client_session_initialize():
|
|||||||
|
|
||||||
async with server_to_client_send:
|
async with server_to_client_send:
|
||||||
await server_to_client_send.send(
|
await server_to_client_send.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
JSONRPCResponse(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
JSONRPCResponse(
|
||||||
id=jsonrpc_request.root.id,
|
jsonrpc="2.0",
|
||||||
result=result.model_dump(
|
id=jsonrpc_request.root.id,
|
||||||
by_alias=True, mode="json", exclude_none=True
|
result=result.model_dump(
|
||||||
),
|
by_alias=True, mode="json", exclude_none=True
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
jsonrpc_notification = await client_to_server_receive.receive()
|
session_notification = await client_to_server_receive.receive()
|
||||||
|
jsonrpc_notification = session_notification.message
|
||||||
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
||||||
initialized_notification = ClientNotification.model_validate(
|
initialized_notification = ClientNotification.model_validate(
|
||||||
jsonrpc_notification.model_dump(
|
jsonrpc_notification.model_dump(
|
||||||
@@ -116,10 +121,10 @@ async def test_client_session_initialize():
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_client_session_custom_client_info():
|
async def test_client_session_custom_client_info():
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](1)
|
](1)
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](1)
|
](1)
|
||||||
|
|
||||||
custom_client_info = Implementation(name="test-client", version="1.2.3")
|
custom_client_info = Implementation(name="test-client", version="1.2.3")
|
||||||
@@ -128,7 +133,8 @@ async def test_client_session_custom_client_info():
|
|||||||
async def mock_server():
|
async def mock_server():
|
||||||
nonlocal received_client_info
|
nonlocal received_client_info
|
||||||
|
|
||||||
jsonrpc_request = await client_to_server_receive.receive()
|
session_message = await client_to_server_receive.receive()
|
||||||
|
jsonrpc_request = session_message.message
|
||||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||||
request = ClientRequest.model_validate(
|
request = ClientRequest.model_validate(
|
||||||
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
@@ -146,13 +152,15 @@ async def test_client_session_custom_client_info():
|
|||||||
|
|
||||||
async with server_to_client_send:
|
async with server_to_client_send:
|
||||||
await server_to_client_send.send(
|
await server_to_client_send.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
JSONRPCResponse(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
JSONRPCResponse(
|
||||||
id=jsonrpc_request.root.id,
|
jsonrpc="2.0",
|
||||||
result=result.model_dump(
|
id=jsonrpc_request.root.id,
|
||||||
by_alias=True, mode="json", exclude_none=True
|
result=result.model_dump(
|
||||||
),
|
by_alias=True, mode="json", exclude_none=True
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -181,10 +189,10 @@ async def test_client_session_custom_client_info():
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_client_session_default_client_info():
|
async def test_client_session_default_client_info():
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](1)
|
](1)
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](1)
|
](1)
|
||||||
|
|
||||||
received_client_info = None
|
received_client_info = None
|
||||||
@@ -192,7 +200,8 @@ async def test_client_session_default_client_info():
|
|||||||
async def mock_server():
|
async def mock_server():
|
||||||
nonlocal received_client_info
|
nonlocal received_client_info
|
||||||
|
|
||||||
jsonrpc_request = await client_to_server_receive.receive()
|
session_message = await client_to_server_receive.receive()
|
||||||
|
jsonrpc_request = session_message.message
|
||||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||||
request = ClientRequest.model_validate(
|
request = ClientRequest.model_validate(
|
||||||
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
@@ -210,13 +219,15 @@ async def test_client_session_default_client_info():
|
|||||||
|
|
||||||
async with server_to_client_send:
|
async with server_to_client_send:
|
||||||
await server_to_client_send.send(
|
await server_to_client_send.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
JSONRPCResponse(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
JSONRPCResponse(
|
||||||
id=jsonrpc_request.root.id,
|
jsonrpc="2.0",
|
||||||
result=result.model_dump(
|
id=jsonrpc_request.root.id,
|
||||||
by_alias=True, mode="json", exclude_none=True
|
result=result.model_dump(
|
||||||
),
|
by_alias=True, mode="json", exclude_none=True
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import shutil
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||||
|
|
||||||
tee: str = shutil.which("tee") # type: ignore
|
tee: str = shutil.which("tee") # type: ignore
|
||||||
@@ -22,7 +23,8 @@ async def test_stdio_client():
|
|||||||
|
|
||||||
async with write_stream:
|
async with write_stream:
|
||||||
for message in messages:
|
for message in messages:
|
||||||
await write_stream.send(message)
|
session_message = SessionMessage(message)
|
||||||
|
await write_stream.send(session_message)
|
||||||
|
|
||||||
read_messages = []
|
read_messages = []
|
||||||
async with read_stream:
|
async with read_stream:
|
||||||
@@ -30,7 +32,7 @@ async def test_stdio_client():
|
|||||||
if isinstance(message, Exception):
|
if isinstance(message, Exception):
|
||||||
raise message
|
raise message
|
||||||
|
|
||||||
read_messages.append(message)
|
read_messages.append(message.message)
|
||||||
if len(read_messages) == 2:
|
if len(read_messages) == 2:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import pytest
|
|||||||
|
|
||||||
from mcp.server.lowlevel import NotificationOptions, Server
|
from mcp.server.lowlevel import NotificationOptions, Server
|
||||||
from mcp.server.models import InitializationOptions
|
from mcp.server.models import InitializationOptions
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
LATEST_PROTOCOL_VERSION,
|
LATEST_PROTOCOL_VERSION,
|
||||||
ClientCapabilities,
|
ClientCapabilities,
|
||||||
@@ -64,8 +65,10 @@ async def test_request_id_match() -> None:
|
|||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
await client_writer.send(JSONRPCMessage(root=init_req))
|
await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req)))
|
||||||
await server_reader.receive() # Get init response but don't need to check it
|
response = (
|
||||||
|
await server_reader.receive()
|
||||||
|
) # Get init response but don't need to check it
|
||||||
|
|
||||||
# Send initialized notification
|
# Send initialized notification
|
||||||
initialized_notification = JSONRPCNotification(
|
initialized_notification = JSONRPCNotification(
|
||||||
@@ -73,21 +76,23 @@ async def test_request_id_match() -> None:
|
|||||||
params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
|
params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
|
||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
)
|
)
|
||||||
await client_writer.send(JSONRPCMessage(root=initialized_notification))
|
await client_writer.send(
|
||||||
|
SessionMessage(JSONRPCMessage(root=initialized_notification))
|
||||||
|
)
|
||||||
|
|
||||||
# Send ping request with custom ID
|
# Send ping request with custom ID
|
||||||
ping_request = JSONRPCRequest(
|
ping_request = JSONRPCRequest(
|
||||||
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
|
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
await client_writer.send(JSONRPCMessage(root=ping_request))
|
await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request)))
|
||||||
|
|
||||||
# Read response
|
# Read response
|
||||||
response = await server_reader.receive()
|
response = await server_reader.receive()
|
||||||
|
|
||||||
# Verify response ID matches request ID
|
# Verify response ID matches request ID
|
||||||
assert (
|
assert (
|
||||||
response.root.id == custom_request_id
|
response.message.root.id == custom_request_id
|
||||||
), "Response ID should match request ID"
|
), "Response ID should match request ID"
|
||||||
|
|
||||||
# Cancel server task
|
# Cancel server task
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from pydantic import TypeAdapter
|
|||||||
from mcp.server.fastmcp import Context, FastMCP
|
from mcp.server.fastmcp import Context, FastMCP
|
||||||
from mcp.server.lowlevel.server import NotificationOptions, Server
|
from mcp.server.lowlevel.server import NotificationOptions, Server
|
||||||
from mcp.server.models import InitializationOptions
|
from mcp.server.models import InitializationOptions
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
ClientCapabilities,
|
ClientCapabilities,
|
||||||
Implementation,
|
Implementation,
|
||||||
@@ -82,41 +83,49 @@ async def test_lowlevel_server_lifespan():
|
|||||||
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
||||||
)
|
)
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
root=JSONRPCRequest(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCRequest(
|
||||||
id=1,
|
jsonrpc="2.0",
|
||||||
method="initialize",
|
id=1,
|
||||||
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
method="initialize",
|
||||||
|
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await receive_stream2.receive()
|
response = await receive_stream2.receive()
|
||||||
|
response = response.message
|
||||||
|
|
||||||
# Send initialized notification
|
# Send initialized notification
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
root=JSONRPCNotification(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCNotification(
|
||||||
method="notifications/initialized",
|
jsonrpc="2.0",
|
||||||
|
method="notifications/initialized",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the tool to verify lifespan context
|
# Call the tool to verify lifespan context
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
root=JSONRPCRequest(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCRequest(
|
||||||
id=2,
|
jsonrpc="2.0",
|
||||||
method="tools/call",
|
id=2,
|
||||||
params={"name": "check_lifespan", "arguments": {}},
|
method="tools/call",
|
||||||
|
params={"name": "check_lifespan", "arguments": {}},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get response and verify
|
# Get response and verify
|
||||||
response = await receive_stream2.receive()
|
response = await receive_stream2.receive()
|
||||||
|
response = response.message
|
||||||
assert response.root.result["content"][0]["text"] == "true"
|
assert response.root.result["content"][0]["text"] == "true"
|
||||||
|
|
||||||
# Cancel server task
|
# Cancel server task
|
||||||
@@ -178,41 +187,49 @@ async def test_fastmcp_server_lifespan():
|
|||||||
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
||||||
)
|
)
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
root=JSONRPCRequest(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCRequest(
|
||||||
id=1,
|
jsonrpc="2.0",
|
||||||
method="initialize",
|
id=1,
|
||||||
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
method="initialize",
|
||||||
|
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await receive_stream2.receive()
|
response = await receive_stream2.receive()
|
||||||
|
response = response.message
|
||||||
|
|
||||||
# Send initialized notification
|
# Send initialized notification
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
root=JSONRPCNotification(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCNotification(
|
||||||
method="notifications/initialized",
|
jsonrpc="2.0",
|
||||||
|
method="notifications/initialized",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the tool to verify lifespan context
|
# Call the tool to verify lifespan context
|
||||||
await send_stream1.send(
|
await send_stream1.send(
|
||||||
JSONRPCMessage(
|
SessionMessage(
|
||||||
root=JSONRPCRequest(
|
JSONRPCMessage(
|
||||||
jsonrpc="2.0",
|
root=JSONRPCRequest(
|
||||||
id=2,
|
jsonrpc="2.0",
|
||||||
method="tools/call",
|
id=2,
|
||||||
params={"name": "check_lifespan", "arguments": {}},
|
method="tools/call",
|
||||||
|
params={"name": "check_lifespan", "arguments": {}},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get response and verify
|
# Get response and verify
|
||||||
response = await receive_stream2.receive()
|
response = await receive_stream2.receive()
|
||||||
|
response = response.message
|
||||||
assert response.root.result["content"][0]["text"] == "true"
|
assert response.root.result["content"][0]["text"] == "true"
|
||||||
|
|
||||||
# Cancel server task
|
# Cancel server task
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ from mcp.server import Server
|
|||||||
from mcp.server.lowlevel import NotificationOptions
|
from mcp.server.lowlevel import NotificationOptions
|
||||||
from mcp.server.models import InitializationOptions
|
from mcp.server.models import InitializationOptions
|
||||||
from mcp.server.session import ServerSession
|
from mcp.server.session import ServerSession
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
ClientResult,
|
ClientResult,
|
||||||
JSONRPCMessage,
|
|
||||||
ServerNotification,
|
ServerNotification,
|
||||||
ServerRequest,
|
ServerRequest,
|
||||||
Tool,
|
Tool,
|
||||||
@@ -46,10 +46,10 @@ async def test_lowlevel_server_tool_annotations():
|
|||||||
]
|
]
|
||||||
|
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](10)
|
](10)
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](10)
|
](10)
|
||||||
|
|
||||||
# Message handler for client
|
# Message handler for client
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ from mcp.server import Server
|
|||||||
from mcp.server.lowlevel import NotificationOptions
|
from mcp.server.lowlevel import NotificationOptions
|
||||||
from mcp.server.models import InitializationOptions
|
from mcp.server.models import InitializationOptions
|
||||||
from mcp.server.session import ServerSession
|
from mcp.server.session import ServerSession
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
ClientNotification,
|
ClientNotification,
|
||||||
InitializedNotification,
|
InitializedNotification,
|
||||||
JSONRPCMessage,
|
|
||||||
PromptsCapability,
|
PromptsCapability,
|
||||||
ResourcesCapability,
|
ResourcesCapability,
|
||||||
ServerCapabilities,
|
ServerCapabilities,
|
||||||
@@ -21,10 +21,10 @@ from mcp.types import (
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_server_session_initialize():
|
async def test_server_session_initialize():
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](1)
|
](1)
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
JSONRPCMessage
|
SessionMessage
|
||||||
](1)
|
](1)
|
||||||
|
|
||||||
# Create a message handler to catch exceptions
|
# Create a message handler to catch exceptions
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import anyio
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mcp.server.stdio import stdio_server
|
from mcp.server.stdio import stdio_server
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||||
|
|
||||||
|
|
||||||
@@ -29,7 +30,7 @@ async def test_stdio_server():
|
|||||||
async for message in read_stream:
|
async for message in read_stream:
|
||||||
if isinstance(message, Exception):
|
if isinstance(message, Exception):
|
||||||
raise message
|
raise message
|
||||||
received_messages.append(message)
|
received_messages.append(message.message)
|
||||||
if len(received_messages) == 2:
|
if len(received_messages) == 2:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -50,7 +51,8 @@ async def test_stdio_server():
|
|||||||
|
|
||||||
async with write_stream:
|
async with write_stream:
|
||||||
for response in responses:
|
for response in responses:
|
||||||
await write_stream.send(response)
|
session_message = SessionMessage(response)
|
||||||
|
await write_stream.send(session_message)
|
||||||
|
|
||||||
stdout.seek(0)
|
stdout.seek(0)
|
||||||
output_lines = stdout.readlines()
|
output_lines = stdout.readlines()
|
||||||
|
|||||||
Reference in New Issue
Block a user