diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 643e1a2..27ca276 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,9 +68,10 @@ import contextvars import logging import warnings from collections.abc import Awaitable, Callable -from contextlib import AbstractAsyncContextManager, asynccontextmanager +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from typing import Any, AsyncIterator, Generic, Sequence, TypeVar +import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -458,6 +459,30 @@ class Server(Generic[LifespanResultT]): return decorator + async def _handle_message( + self, + message: RequestResponder[types.ClientRequest, types.ServerResult] + | types.ClientNotification + | Exception, + session: ServerSession, + lifespan_context: LifespanResultT, + raise_exceptions: bool = False, + ): + with warnings.catch_warnings(record=True) as w: + match message: + case ( + RequestResponder(request=types.ClientRequest(root=req)) as responder + ): + with responder: + await self._handle_request( + message, req, session, lifespan_context, raise_exceptions + ) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) + + for warning in w: + logger.info(f"Warning: {warning.category.__name__}: {warning.message}") + async def run( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], @@ -469,41 +494,23 @@ class Server(Generic[LifespanResultT]): # in-process servers. raise_exceptions: bool = False, ): - with warnings.catch_warnings(record=True) as w: - from contextlib import AsyncExitStack - - async with AsyncExitStack() as stack: - lifespan_context = await stack.enter_async_context(self.lifespan(self)) - session = await stack.enter_async_context( - ServerSession(read_stream, write_stream, initialization_options) - ) + async with AsyncExitStack() as stack: + lifespan_context = await stack.enter_async_context(self.lifespan(self)) + session = await stack.enter_async_context( + ServerSession(read_stream, write_stream, initialization_options) + ) + async with anyio.create_task_group() as tg: async for message in session.incoming_messages: logger.debug(f"Received message: {message}") - match message: - case ( - RequestResponder( - request=types.ClientRequest(root=req) - ) as responder - ): - with responder: - await self._handle_request( - message, - req, - session, - lifespan_context, - raise_exceptions, - ) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) - - for warning in w: - logger.info( - "Warning: %s: %s", - warning.category.__name__, - warning.message, - ) + tg.start_soon( + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) async def _handle_request( self,