feat: add lifespan support to low-level MCP server

Adds a context manager based lifespan API in mcp.server.lowlevel.server to manage server lifecycles in a
type-safe way. This enables servers to:
- Initialize resources on startup and clean them up on shutdown
- Pass context data from startup to request handlers
- Support async startup/shutdown operations
This commit is contained in:
David Soria Parra
2025-02-11 12:14:58 +00:00
parent f10665db4c
commit 2c7bd8343e
2 changed files with 59 additions and 25 deletions

View File

@@ -68,7 +68,8 @@ import contextvars
import logging import logging
import warnings import warnings
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
@@ -101,13 +102,36 @@ class NotificationOptions:
self.tools_changed = tools_changed self.tools_changed = tools_changed
class Server: LifespanResultT = TypeVar("LifespanResultT")
@asynccontextmanager
async def lifespan(server: "Server") -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.
Args:
server: The server instance this lifespan is managing
Returns:
An empty context object
"""
yield {}
class Server(Generic[LifespanResultT]):
def __init__( def __init__(
self, name: str, version: str | None = None, instructions: str | None = None self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
["Server"], AbstractAsyncContextManager[LifespanResultT]
] = lifespan,
): ):
self.name = name self.name = name
self.version = version self.version = version
self.instructions = instructions self.instructions = instructions
self.lifespan = lifespan
self.request_handlers: dict[ self.request_handlers: dict[
type, Callable[..., Awaitable[types.ServerResult]] type, Callable[..., Awaitable[types.ServerResult]]
] = { ] = {
@@ -446,6 +470,7 @@ class Server:
raise_exceptions: bool = False, raise_exceptions: bool = False,
): ):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
async with self.lifespan(self) as lifespan_context:
async with ServerSession( async with ServerSession(
read_stream, write_stream, initialization_options read_stream, write_stream, initialization_options
) as session: ) as session:
@@ -460,14 +485,20 @@ class Server:
): ):
with responder: with responder:
await self._handle_request( await self._handle_request(
message, req, session, raise_exceptions message,
req,
session,
lifespan_context,
raise_exceptions,
) )
case types.ClientNotification(root=notify): case types.ClientNotification(root=notify):
await self._handle_notification(notify) await self._handle_notification(notify)
for warning in w: for warning in w:
logger.info( logger.info(
f"Warning: {warning.category.__name__}: {warning.message}" "Warning: %s: %s",
warning.category.__name__,
warning.message,
) )
async def _handle_request( async def _handle_request(
@@ -475,6 +506,7 @@ class Server:
message: RequestResponder, message: RequestResponder,
req: Any, req: Any,
session: ServerSession, session: ServerSession,
lifespan_context: object,
raise_exceptions: bool, raise_exceptions: bool,
): ):
logger.info(f"Processing request of type {type(req).__name__}") logger.info(f"Processing request of type {type(req).__name__}")
@@ -491,6 +523,7 @@ class Server:
message.request_id, message.request_id,
message.request_meta, message.request_meta,
session, session,
lifespan_context,
) )
) )
response = await handler(req) response = await handler(req)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, TypeVar from typing import Any, Generic, TypeVar
from mcp.shared.session import BaseSession from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParams from mcp.types import RequestId, RequestParams
@@ -12,3 +12,4 @@ class RequestContext(Generic[SessionT]):
request_id: RequestId request_id: RequestId
meta: RequestParams.Meta | None meta: RequestParams.Meta | None
session: SessionT session: SessionT
lifespan_context: Any