From 2c7bd8343eb92bf6d3d3dbf3e66687b8d47cbd5f Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 11 Feb 2025 12:14:58 +0000 Subject: [PATCH] feat: add lifespan support to low-level MCP server Adds a context manager based lifespan API in mcp.server.lowlevel.server to manage server lifecycles in a type-safe way. This enables servers to: - Initialize resources on startup and clean them up on shutdown - Pass context data from startup to request handlers - Support async startup/shutdown operations --- src/mcp/server/lowlevel/server.py | 81 ++++++++++++++++++++++--------- src/mcp/shared/context.py | 3 +- 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 3d91722..28942cf 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,7 +68,8 @@ import contextvars import logging import warnings from collections.abc import Awaitable, Callable -from typing import Any, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, AsyncIterator, Generic, Sequence, TypeVar from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -101,13 +102,36 @@ class NotificationOptions: self.tools_changed = tools_changed -class Server: +LifespanResultT = TypeVar("LifespanResultT") + + +@asynccontextmanager +async def lifespan(server: "Server") -> AsyncIterator[object]: + """Default lifespan context manager that does nothing. + + Args: + server: The server instance this lifespan is managing + + Returns: + An empty context object + """ + yield {} + + +class Server(Generic[LifespanResultT]): def __init__( - self, name: str, version: str | None = None, instructions: str | None = None + self, + name: str, + version: str | None = None, + instructions: str | None = None, + lifespan: Callable[ + ["Server"], AbstractAsyncContextManager[LifespanResultT] + ] = lifespan, ): self.name = name self.version = version self.instructions = instructions + self.lifespan = lifespan self.request_handlers: dict[ type, Callable[..., Awaitable[types.ServerResult]] ] = { @@ -446,35 +470,43 @@ class Server: raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: - async with ServerSession( - read_stream, write_stream, initialization_options - ) as session: - async for message in session.incoming_messages: - logger.debug(f"Received message: {message}") + async with self.lifespan(self) as lifespan_context: + async with ServerSession( + read_stream, write_stream, initialization_options + ) as session: + async for message in session.incoming_messages: + logger.debug(f"Received message: {message}") - match message: - case ( - RequestResponder( - request=types.ClientRequest(root=req) - ) as responder - ): - with responder: - await self._handle_request( - message, req, session, raise_exceptions - ) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) + match message: + case ( + RequestResponder( + request=types.ClientRequest(root=req) + ) as responder + ): + with responder: + await self._handle_request( + message, + req, + session, + lifespan_context, + raise_exceptions, + ) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) - for warning in w: - logger.info( - f"Warning: {warning.category.__name__}: {warning.message}" - ) + for warning in w: + logger.info( + "Warning: %s: %s", + warning.category.__name__, + warning.message, + ) async def _handle_request( self, message: RequestResponder, req: Any, session: ServerSession, + lifespan_context: object, raise_exceptions: bool, ): logger.info(f"Processing request of type {type(req).__name__}") @@ -491,6 +523,7 @@ class Server: message.request_id, message.request_meta, session, + lifespan_context, ) ) response = await handler(req) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 760d558..50e5d51 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams @@ -12,3 +12,4 @@ class RequestContext(Generic[SessionT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT + lifespan_context: Any