Revert "refactor: reorganize message handling for better type safety and clar…" (#282)

This reverts commit 9d0f2daddb.
This commit is contained in:
Marcelo Trylesinski
2025-03-14 10:50:46 +01:00
committed by GitHub
parent ebb81d3b2b
commit 7196604468
17 changed files with 151 additions and 283 deletions

View File

@@ -1,11 +1,12 @@
from datetime import timedelta
from typing import Any, Protocol
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -58,8 +59,8 @@ class ClientSession(
):
def __init__(
self,
read_stream: ReadStream,
write_stream: WriteStream,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,

View File

@@ -6,16 +6,10 @@ from urllib.parse import urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
logger = logging.getLogger(__name__)
@@ -37,11 +31,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: ReadStream
read_stream_writer: ReadStreamWriter
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: WriteStream
write_stream_reader: WriteStreamReader
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -90,11 +84,8 @@ async def sse_client(
case "message":
try:
message = MessageFrame(
message=types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
),
raw=sse,
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(
f"Received server message: {message}"

View File

@@ -1,7 +1,7 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -10,7 +10,6 @@ from websockets.asyncio.client import connect as ws_connect
from websockets.typing import Subprotocol
import mcp.types as types
from mcp.types import MessageFrame
logger = logging.getLogger(__name__)
@@ -20,8 +19,8 @@ async def websocket_client(
url: str,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[MessageFrame[Any] | Exception],
MemoryObjectSendStream[MessageFrame[Any]],
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
MemoryObjectSendStream[types.JSONRPCMessage],
],
None,
]:
@@ -54,11 +53,7 @@ async def websocket_client(
async with read_stream_writer:
async for raw_text in ws:
try:
json_message = types.JSONRPCMessage.model_validate_json(
raw_text
)
# Create MessageFrame with JSON message as root
message = MessageFrame(message=json_message, raw=raw_text)
message = types.JSONRPCMessage.model_validate_json(raw_text)
await read_stream_writer.send(message)
except ValidationError as exc:
# If JSON parse or model validation fails, send the exception
@@ -71,8 +66,8 @@ async def websocket_client(
"""
async with write_stream_reader:
async for message in write_stream_reader:
# Extract the JSON-RPC message from MessageFrame and convert to JSON
msg_dict = message.message.model_dump(
# Convert to a dict, then to JSON
msg_dict = message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await ws.send(json.dumps(msg_dict))