refactor: reorganize message handling for better type safety and clarity (#239)

* refactor: improve typing with memory stream type aliases

Move memory stream type definitions to models.py and use them throughout
the codebase for better type safety and maintainability.

GitHub-Issue:#201

* refactor: move streams to ParsedMessage

* refactor: update test files to use ParsedMessage

Updates test files to work with the ParsedMessage stream type aliases
and fixes a line length issue in test_201_client_hangs_on_logging.py.

Github-Issue:#201

* refactor: rename ParsedMessage to MessageFrame for clarity

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: move MessageFrame class to types.py for better code organization

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>

* fix pyright

* refactor: update websocket client to use MessageFrame

Modified the websocket client to work with the new MessageFrame type,
preserving raw message text and properly extracting the root JSON-RPC
message when sending.

Github-Issue:#204

* fix: use NoneType instead of None for type parameters in MessageFrame

🤖 Generated with [Claude Code](https://claude.ai/code)

* refactor: rename root to message
This commit is contained in:
David Soria Parra
2025-03-13 13:44:55 +00:00
committed by GitHub
parent ad7f7a5473
commit 9d0f2daddb
17 changed files with 283 additions and 151 deletions

View File

@@ -22,12 +22,18 @@ from mcp.types import (
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
MessageFrame,
RequestParams,
ServerNotification,
ServerRequest,
ServerResult,
)
ReadStream = MemoryObjectReceiveStream[MessageFrame | Exception]
ReadStreamWriter = MemoryObjectSendStream[MessageFrame | Exception]
WriteStream = MemoryObjectSendStream[MessageFrame]
WriteStreamReader = MemoryObjectReceiveStream[MessageFrame]
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
@@ -165,8 +171,8 @@ class BaseSession(
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: ReadStream,
write_stream: WriteStream,
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
@@ -242,7 +248,9 @@ class BaseSession(
# TODO: Support progress callbacks
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
await self._write_stream.send(
MessageFrame(message=JSONRPCMessage(jsonrpc_request), raw=None)
)
try:
with anyio.fail_after(
@@ -278,14 +286,18 @@ class BaseSession(
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
await self._write_stream.send(
MessageFrame(message=JSONRPCMessage(jsonrpc_notification), raw=None)
)
async def _send_response(
self, request_id: RequestId, response: SendResultT | ErrorData
) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
await self._write_stream.send(
MessageFrame(message=JSONRPCMessage(jsonrpc_error), raw=None)
)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
@@ -294,7 +306,9 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True
),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
await self._write_stream.send(
MessageFrame(message=JSONRPCMessage(jsonrpc_response), raw=None)
)
async def _receive_loop(self) -> None:
async with (
@@ -302,10 +316,13 @@ class BaseSession(
self._write_stream,
self._incoming_message_stream_writer,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._incoming_message_stream_writer.send(message)
elif isinstance(message.root, JSONRPCRequest):
async for raw_message in self._read_stream:
if isinstance(raw_message, Exception):
await self._incoming_message_stream_writer.send(raw_message)
continue
message = raw_message.message
if isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True