feat: add lifespan support to low-level MCP server

Adds a context manager based lifespan API in mcp.server.lowlevel.server to manage server lifecycles in a
type-safe way. This enables servers to:
- Initialize resources on startup and clean them up on shutdown
- Pass context data from startup to request handlers
- Support async startup/shutdown operations
This commit is contained in:
David Soria Parra
2025-02-11 12:14:58 +00:00
parent f10665db4c
commit 2c7bd8343e
2 changed files with 59 additions and 25 deletions

View File

@@ -68,7 +68,8 @@ import contextvars
import logging
import warnings
from collections.abc import Awaitable, Callable
from typing import Any, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
@@ -101,13 +102,36 @@ class NotificationOptions:
self.tools_changed = tools_changed
class Server:
LifespanResultT = TypeVar("LifespanResultT")
@asynccontextmanager
async def lifespan(server: "Server") -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.
Args:
server: The server instance this lifespan is managing
Returns:
An empty context object
"""
yield {}
class Server(Generic[LifespanResultT]):
def __init__(
self, name: str, version: str | None = None, instructions: str | None = None
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
["Server"], AbstractAsyncContextManager[LifespanResultT]
] = lifespan,
):
self.name = name
self.version = version
self.instructions = instructions
self.lifespan = lifespan
self.request_handlers: dict[
type, Callable[..., Awaitable[types.ServerResult]]
] = {
@@ -446,35 +470,43 @@ class Server:
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
async with self.lifespan(self) as lifespan_context:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
match message:
case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
message, req, session, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
match message:
case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
message,
req,
session,
lifespan_context,
raise_exceptions,
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
)
for warning in w:
logger.info(
"Warning: %s: %s",
warning.category.__name__,
warning.message,
)
async def _handle_request(
self,
message: RequestResponder,
req: Any,
session: ServerSession,
lifespan_context: object,
raise_exceptions: bool,
):
logger.info(f"Processing request of type {type(req).__name__}")
@@ -491,6 +523,7 @@ class Server:
message.request_id,
message.request_meta,
session,
lifespan_context,
)
)
response = await handler(req)