Fix #201: Move incoming message stream from BaseSession to ServerSession (#325)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
David Soria Parra
2025-03-24 14:14:14 +00:00
committed by GitHub
parent 9ae4df85fb
commit 568cbd1a66
9 changed files with 168 additions and 109 deletions

View File

@@ -7,9 +7,11 @@ from urllib.parse import urlparse
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.session import RequestResponder
from mcp.types import JSONRPCMessage
if not sys.warnoptions:
@@ -21,26 +23,25 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client")
async def receive_loop(session: ClientSession):
logger.info("Starting receive loop")
async for message in session.incoming_messages:
if isinstance(message, Exception):
logger.error("Error: %s", message)
continue
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
logger.error("Error: %s", message)
return
logger.info("Received message from server: %s", message)
logger.info("Received message from server: %s", message)
async def run_session(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
):
async with (
ClientSession(read_stream, write_stream) as session,
anyio.create_task_group() as tg,
):
tg.start_soon(receive_loop, session)
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
logger.info("Initializing session")
await session.initialize()
logger.info("Initialized")

View File

@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any, Protocol
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
@@ -31,6 +32,23 @@ class LoggingFnT(Protocol):
) -> None: ...
class MessageHandlerFnT(Protocol):
async def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None: ...
async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
await anyio.lowlevel.checkpoint()
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
@@ -78,6 +96,7 @@ class ClientSession(
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
) -> None:
super().__init__(
read_stream,
@@ -89,6 +108,7 @@ class ClientSession(
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
@@ -337,10 +357,20 @@ class ClientSession(
types.ClientResult(root=types.EmptyResult())
)
async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)
async def _received_notification(
self, notification: types.ServerNotification
) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)

View File

@@ -61,6 +61,12 @@ class InitializationState(Enum):
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
ServerRequestResponder = (
RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception
)
class ServerSession(
BaseSession[
@@ -85,6 +91,15 @@ class ServerSession(
)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[ServerRequestResponder](0)
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)
@property
def client_params(self) -> types.InitializeRequestParams | None:
@@ -291,3 +306,12 @@ class ServerSession(
)
)
)
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
await self._incoming_message_stream_writer.send(req)
@property
def incoming_messages(
self,
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
return self._incoming_message_stream_reader

View File

@@ -10,7 +10,13 @@ from typing import Any
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
from mcp.client.session import (
ClientSession,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
)
from mcp.server import Server
from mcp.types import JSONRPCMessage
@@ -58,6 +64,7 @@ async def create_connected_server_and_client_session(
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -87,6 +94,7 @@ async def create_connected_server_and_client_session(
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
logging_callback=logging_callback,
message_handler=message_handler,
) as client_session:
await client_session.initialize()
yield client_session

View File

@@ -189,19 +189,6 @@ class BaseSession(
self._in_flight = {}
self._exit_stack = AsyncExitStack()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)
async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
@@ -312,11 +299,10 @@ class BaseSession(
async with (
self._read_stream,
self._write_stream,
self._incoming_message_stream_writer,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._incoming_message_stream_writer.send(message)
await self._handle_incoming(message)
elif isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
@@ -336,8 +322,9 @@ class BaseSession(
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._incoming_message_stream_writer.send(responder)
await self._handle_incoming(responder)
elif isinstance(message.root, JSONRPCNotification):
try:
@@ -353,9 +340,7 @@ class BaseSession(
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
@@ -367,7 +352,7 @@ class BaseSession(
if stream:
await stream.send(message.root)
else:
await self._incoming_message_stream_writer.send(
await self._handle_incoming(
RuntimeError(
"Received response with an unknown "
f"request ID: {message}"
@@ -399,12 +384,11 @@ class BaseSession(
processed.
"""
@property
def incoming_messages(
async def _handle_incoming(
self,
) -> MemoryObjectReceiveStream[
RequestResponder[ReceiveRequestT, SendResultT]
req: RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]:
return self._incoming_message_stream_reader
| Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass

View File

@@ -1,11 +1,12 @@
from typing import Literal
import anyio
import pytest
import mcp.types as types
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
from mcp.shared.session import RequestResponder
from mcp.types import (
LoggingMessageNotificationParams,
TextContent,
@@ -46,40 +47,37 @@ async def test_logging_callback():
)
return True
async with anyio.create_task_group() as tg:
async with create_session(
server._mcp_server, logging_callback=logging_collector
) as client_session:
# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
raise message
async def listen_session():
try:
async for message in client_session.incoming_messages:
if isinstance(message, Exception):
raise message
except anyio.EndOfStream:
pass
async with create_session(
server._mcp_server,
logging_callback=logging_collector,
message_handler=message_handler,
) as client_session:
# First verify our test tool works
result = await client_session.call_tool("test_tool", {})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"
tg.start_soon(listen_session)
# First verify our test tool works
result = await client_session.call_tool("test_tool", {})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"
# Now send a log message via our tool
log_result = await client_session.call_tool(
"test_tool_with_log",
{
"message": "Test log message",
"level": "info",
"logger": "test_logger",
},
)
assert log_result.isError is False
assert len(logging_collector.log_messages) == 1
assert logging_collector.log_messages[
0
] == LoggingMessageNotificationParams(
level="info", logger="test_logger", data="Test log message"
)
# Now send a log message via our tool
log_result = await client_session.call_tool(
"test_tool_with_log",
{
"message": "Test log message",
"level": "info",
"logger": "test_logger",
},
)
assert log_result.isError is False
assert len(logging_collector.log_messages) == 1
assert logging_collector.log_messages[0] == LoggingMessageNotificationParams(
level="info", logger="test_logger", data="Test log message"
)

View File

@@ -1,7 +1,9 @@
import anyio
import pytest
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.shared.session import RequestResponder
from mcp.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification,
@@ -75,13 +77,21 @@ async def test_client_session_initialize():
)
)
async def listen_session():
async for message in session.incoming_messages:
if isinstance(message, Exception):
raise message
# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
raise message
async with (
ClientSession(server_to_client_receive, client_to_server_send) as session,
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
@@ -89,7 +99,6 @@ async def test_client_session_initialize():
server_to_client_receive,
):
tg.start_soon(mock_server)
tg.start_soon(listen_session)
result = await session.initialize()
# Assert the result

View File

@@ -6,6 +6,7 @@ from pathlib import Path
import anyio
import pytest
from anyio.abc import TaskStatus
from mcp.client.session import ClientSession
from mcp.server.lowlevel import Server
@@ -54,15 +55,21 @@ async def test_notification_validation_error(tmp_path: Path):
return [TextContent(type="text", text=f"fast {request_count}")]
return [TextContent(type="text", text=f"unknown {request_count}")]
async def server_handler(read_stream, write_stream):
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
raise_exceptions=True,
)
async def server_handler(
read_stream,
write_stream,
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
with anyio.CancelScope() as scope:
task_status.started(scope) # type: ignore
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
raise_exceptions=True,
)
async def client(read_stream, write_stream):
async def client(read_stream, write_stream, scope):
# Use a timeout that's:
# - Long enough for fast operations (>10ms)
# - Short enough for slow operations (<200ms)
@@ -90,22 +97,13 @@ async def test_notification_validation_error(tmp_path: Path):
# proving server is still responsive
result = await session.call_tool("fast")
assert result.content == [TextContent(type="text", text="fast 3")]
scope.cancel()
# Run server and client in separate task groups to avoid cancellation
server_writer, server_reader = anyio.create_memory_object_stream(1)
client_writer, client_reader = anyio.create_memory_object_stream(1)
server_ready = anyio.Event()
async def wrapped_server_handler(read_stream, write_stream):
server_ready.set()
await server_handler(read_stream, write_stream)
async with anyio.create_task_group() as tg:
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
# Wait for server to start and initialize
with anyio.fail_after(1): # Timeout after 1 second
await server_ready.wait()
scope = await tg.start(server_handler, server_reader, client_writer)
# Run client in a separate task to avoid cancellation
async with anyio.create_task_group() as client_tg:
client_tg.start_soon(client, client_reader, server_writer)
tg.start_soon(client, client_reader, server_writer, scope)

View File

@@ -1,11 +1,13 @@
import anyio
import pytest
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.server import Server
from mcp.server.lowlevel import NotificationOptions
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.session import RequestResponder
from mcp.types import (
ClientNotification,
InitializedNotification,
@@ -25,10 +27,14 @@ async def test_server_session_initialize():
JSONRPCMessage
](1)
async def run_client(client: ClientSession):
async for message in client_session.incoming_messages:
if isinstance(message, Exception):
raise message
# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
raise message
received_initialized = False
@@ -57,11 +63,12 @@ async def test_server_session_initialize():
try:
async with (
ClientSession(
server_to_client_receive, client_to_server_send
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_client, client_session)
tg.start_soon(run_server)
await client_session.initialize()