diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 28942cf..a4a8510 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -470,36 +470,40 @@ class Server(Generic[LifespanResultT]): raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: - async with self.lifespan(self) as lifespan_context: - async with ServerSession( - read_stream, write_stream, initialization_options - ) as session: - async for message in session.incoming_messages: - logger.debug(f"Received message: {message}") + from contextlib import AsyncExitStack - match message: - case ( - RequestResponder( - request=types.ClientRequest(root=req) - ) as responder - ): - with responder: - await self._handle_request( - message, - req, - session, - lifespan_context, - raise_exceptions, - ) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) + async with AsyncExitStack() as stack: + lifespan_context = await stack.enter_async_context(self.lifespan(self)) + session = await stack.enter_async_context( + ServerSession(read_stream, write_stream, initialization_options) + ) - for warning in w: - logger.info( - "Warning: %s: %s", - warning.category.__name__, - warning.message, - ) + async for message in session.incoming_messages: + logger.debug(f"Received message: {message}") + + match message: + case ( + RequestResponder( + request=types.ClientRequest(root=req) + ) as responder + ): + with responder: + await self._handle_request( + message, + req, + session, + lifespan_context, + raise_exceptions, + ) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) + + for warning in w: + logger.info( + "Warning: %s: %s", + warning.category.__name__, + warning.message, + ) async def _handle_request( self,