Fix #201: Move incoming message stream from BaseSession to ServerSession (#325)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
David Soria Parra
2025-03-24 14:14:14 +00:00
committed by GitHub
parent 9ae4df85fb
commit 568cbd1a66
9 changed files with 168 additions and 109 deletions

View File

@@ -7,9 +7,11 @@ from urllib.parse import urlparse
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.session import RequestResponder
from mcp.types import JSONRPCMessage
if not sys.warnoptions:
@@ -21,26 +23,25 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client")
async def receive_loop(session: ClientSession):
logger.info("Starting receive loop")
async for message in session.incoming_messages:
if isinstance(message, Exception):
logger.error("Error: %s", message)
continue
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
logger.error("Error: %s", message)
return
logger.info("Received message from server: %s", message)
logger.info("Received message from server: %s", message)
async def run_session(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
):
async with (
ClientSession(read_stream, write_stream) as session,
anyio.create_task_group() as tg,
):
tg.start_soon(receive_loop, session)
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
logger.info("Initializing session")
await session.initialize()
logger.info("Initialized")

View File

@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any, Protocol
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
@@ -31,6 +32,23 @@ class LoggingFnT(Protocol):
) -> None: ...
class MessageHandlerFnT(Protocol):
async def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None: ...
async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
await anyio.lowlevel.checkpoint()
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
@@ -78,6 +96,7 @@ class ClientSession(
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
) -> None:
super().__init__(
read_stream,
@@ -89,6 +108,7 @@ class ClientSession(
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
@@ -337,10 +357,20 @@ class ClientSession(
types.ClientResult(root=types.EmptyResult())
)
async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)
async def _received_notification(
self, notification: types.ServerNotification
) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)

View File

@@ -61,6 +61,12 @@ class InitializationState(Enum):
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
ServerRequestResponder = (
RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception
)
class ServerSession(
BaseSession[
@@ -85,6 +91,15 @@ class ServerSession(
)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[ServerRequestResponder](0)
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)
@property
def client_params(self) -> types.InitializeRequestParams | None:
@@ -291,3 +306,12 @@ class ServerSession(
)
)
)
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
await self._incoming_message_stream_writer.send(req)
@property
def incoming_messages(
self,
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
return self._incoming_message_stream_reader

View File

@@ -10,7 +10,13 @@ from typing import Any
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
from mcp.client.session import (
ClientSession,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
)
from mcp.server import Server
from mcp.types import JSONRPCMessage
@@ -58,6 +64,7 @@ async def create_connected_server_and_client_session(
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -87,6 +94,7 @@ async def create_connected_server_and_client_session(
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
logging_callback=logging_callback,
message_handler=message_handler,
) as client_session:
await client_session.initialize()
yield client_session

View File

@@ -189,19 +189,6 @@ class BaseSession(
self._in_flight = {}
self._exit_stack = AsyncExitStack()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)
async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
@@ -312,11 +299,10 @@ class BaseSession(
async with (
self._read_stream,
self._write_stream,
self._incoming_message_stream_writer,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._incoming_message_stream_writer.send(message)
await self._handle_incoming(message)
elif isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
@@ -336,8 +322,9 @@ class BaseSession(
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._incoming_message_stream_writer.send(responder)
await self._handle_incoming(responder)
elif isinstance(message.root, JSONRPCNotification):
try:
@@ -353,9 +340,7 @@ class BaseSession(
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
@@ -367,7 +352,7 @@ class BaseSession(
if stream:
await stream.send(message.root)
else:
await self._incoming_message_stream_writer.send(
await self._handle_incoming(
RuntimeError(
"Received response with an unknown "
f"request ID: {message}"
@@ -399,12 +384,11 @@ class BaseSession(
processed.
"""
@property
def incoming_messages(
async def _handle_incoming(
self,
) -> MemoryObjectReceiveStream[
RequestResponder[ReceiveRequestT, SendResultT]
req: RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]:
return self._incoming_message_stream_reader
| Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass