mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2026-01-08 16:34:19 +01:00
Wrap JSONRPC messages with SessionMessage for metadata support (#590)
This commit is contained in:
@@ -11,8 +11,8 @@ import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
@@ -36,8 +36,8 @@ async def message_handler(
|
||||
|
||||
|
||||
async def run_session(
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
client_info: types.Implementation | None = None,
|
||||
):
|
||||
async with ClientSession(
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import BaseSession, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
@@ -92,8 +93,8 @@ class ClientSession(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
|
||||
@@ -10,6 +10,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,11 +32,11 @@ async def sse_client(
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
@@ -97,7 +98,8 @@ async def sse_client(
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
case _:
|
||||
logger.warning(
|
||||
f"Unknown SSE event: {sse.event}"
|
||||
@@ -111,11 +113,13 @@ async def sse_client(
|
||||
async def post_writer(endpoint_url: str):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
async for session_message in write_stream_reader:
|
||||
logger.debug(
|
||||
f"Sending client message: {session_message}"
|
||||
)
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=message.model_dump(
|
||||
json=session_message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
|
||||
@@ -11,6 +11,7 @@ from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
from .win32 import (
|
||||
create_windows_process,
|
||||
@@ -98,11 +99,11 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
@@ -143,7 +144,8 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
@@ -152,8 +154,10 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
||||
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
async for session_message in write_stream_reader:
|
||||
json = session_message.message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
await process.stdin.send(
|
||||
(json + "\n").encode(
|
||||
encoding=server.encoding,
|
||||
|
||||
@@ -15,6 +15,7 @@ import anyio
|
||||
import httpx
|
||||
from httpx_sse import EventSource, aconnect_sse
|
||||
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.types import (
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
@@ -52,10 +53,10 @@ async def streamablehttp_client(
|
||||
"""
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage | Exception
|
||||
SessionMessage | Exception
|
||||
](0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage
|
||||
SessionMessage
|
||||
](0)
|
||||
|
||||
async def get_stream():
|
||||
@@ -86,7 +87,8 @@ async def streamablehttp_client(
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"GET message: {message}")
|
||||
await read_stream_writer.send(message)
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error parsing GET message: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
@@ -100,7 +102,8 @@ async def streamablehttp_client(
|
||||
nonlocal session_id
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
message = session_message.message
|
||||
# Add session ID to headers if we have one
|
||||
post_headers = request_headers.copy()
|
||||
if session_id:
|
||||
@@ -141,9 +144,10 @@ async def streamablehttp_client(
|
||||
message="Session terminated",
|
||||
),
|
||||
)
|
||||
await read_stream_writer.send(
|
||||
session_message = SessionMessage(
|
||||
JSONRPCMessage(jsonrpc_error)
|
||||
)
|
||||
await read_stream_writer.send(session_message)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -163,7 +167,8 @@ async def streamablehttp_client(
|
||||
json_message = JSONRPCMessage.model_validate_json(
|
||||
content
|
||||
)
|
||||
await read_stream_writer.send(json_message)
|
||||
session_message = SessionMessage(json_message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error parsing JSON response: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
@@ -175,11 +180,15 @@ async def streamablehttp_client(
|
||||
async for sse in event_source.aiter_sse():
|
||||
if sse.event == "message":
|
||||
try:
|
||||
await read_stream_writer.send(
|
||||
message = (
|
||||
JSONRPCMessage.model_validate_json(
|
||||
sse.data
|
||||
)
|
||||
)
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(
|
||||
session_message
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing message")
|
||||
await read_stream_writer.send(exc)
|
||||
|
||||
@@ -10,6 +10,7 @@ from websockets.asyncio.client import connect as ws_connect
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,8 +20,8 @@ async def websocket_client(
|
||||
url: str,
|
||||
) -> AsyncGenerator[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
MemoryObjectSendStream[SessionMessage],
|
||||
],
|
||||
None,
|
||||
]:
|
||||
@@ -39,10 +40,10 @@ async def websocket_client(
|
||||
# Create two in-memory streams:
|
||||
# - One for incoming messages (read_stream, written by ws_reader)
|
||||
# - One for outgoing messages (write_stream, read by ws_writer)
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
@@ -59,7 +60,8 @@ async def websocket_client(
|
||||
async for raw_text in ws:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(raw_text)
|
||||
await read_stream_writer.send(message)
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except ValidationError as exc:
|
||||
# If JSON parse or model validation fails, send the exception
|
||||
await read_stream_writer.send(exc)
|
||||
@@ -70,9 +72,9 @@ async def websocket_client(
|
||||
sends them to the server.
|
||||
"""
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
# Convert to a dict, then to JSON
|
||||
msg_dict = message.model_dump(
|
||||
msg_dict = session_message.message.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
await ws.send(json.dumps(msg_dict))
|
||||
|
||||
Reference in New Issue
Block a user