refactor: reorganize message handling for better type safety and clarity (#239)

* refactor: improve typing with memory stream type aliases

Move memory stream type definitions to models.py and use them throughout
the codebase for better type safety and maintainability.

GitHub-Issue:#201

* refactor: move streams to ParsedMessage

* refactor: update test files to use ParsedMessage

Updates test files to work with the ParsedMessage stream type aliases
and fixes a line length issue in test_201_client_hangs_on_logging.py.

Github-Issue:#201

* refactor: rename ParsedMessage to MessageFrame for clarity

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: move MessageFrame class to types.py for better code organization

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>

* fix pyright

* refactor: update websocket client to use MessageFrame

Modified the websocket client to work with the new MessageFrame type,
preserving raw message text and properly extracting the root JSON-RPC
message when sending.

Github-Issue:#204

* fix: use NoneType instead of None for type parameters in MessageFrame

🤖 Generated with [Claude Code](https://claude.ai/code)

* refactor: rename root to message
This commit is contained in:
David Soria Parra
2025-03-13 13:44:55 +00:00
committed by GitHub
parent ad7f7a5473
commit 9d0f2daddb
17 changed files with 283 additions and 151 deletions

View File

@@ -1,12 +1,11 @@
from datetime import timedelta
from typing import Any, Protocol
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -59,8 +58,8 @@ class ClientSession(
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_stream: ReadStream,
write_stream: WriteStream,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,

View File

@@ -6,10 +6,16 @@ from urllib.parse import urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame
logger = logging.getLogger(__name__)
@@ -31,11 +37,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: ReadStream
read_stream_writer: ReadStreamWriter
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: WriteStream
write_stream_reader: WriteStreamReader
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -84,8 +90,11 @@ async def sse_client(
case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
message = MessageFrame(
message=types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
),
raw=sse,
)
logger.debug(
f"Received server message: {message}"

View File

@@ -1,7 +1,7 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from typing import Any, AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -10,6 +10,7 @@ from websockets.asyncio.client import connect as ws_connect
from websockets.typing import Subprotocol
import mcp.types as types
from mcp.types import MessageFrame
logger = logging.getLogger(__name__)
@@ -19,8 +20,8 @@ async def websocket_client(
url: str,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
MemoryObjectSendStream[types.JSONRPCMessage],
MemoryObjectReceiveStream[MessageFrame[Any] | Exception],
MemoryObjectSendStream[MessageFrame[Any]],
],
None,
]:
@@ -53,7 +54,11 @@ async def websocket_client(
async with read_stream_writer:
async for raw_text in ws:
try:
message = types.JSONRPCMessage.model_validate_json(raw_text)
json_message = types.JSONRPCMessage.model_validate_json(
raw_text
)
# Create MessageFrame with JSON message as root
message = MessageFrame(message=json_message, raw=raw_text)
await read_stream_writer.send(message)
except ValidationError as exc:
# If JSON parse or model validation fails, send the exception
@@ -66,8 +71,8 @@ async def websocket_client(
"""
async with write_stream_reader:
async for message in write_stream_reader:
# Convert to a dict, then to JSON
msg_dict = message.model_dump(
# Extract the JSON-RPC message from MessageFrame and convert to JSON
msg_dict = message.message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await ws.send(json.dumps(msg_dict))