mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
feat: add lifespan support to low-level MCP server
Adds a context manager based lifespan API in mcp.server.lowlevel.server to manage server lifecycles in a type-safe way. This enables servers to: - Initialize resources on startup and clean them up on shutdown - Pass context data from startup to request handlers - Support async startup/shutdown operations
This commit is contained in:
@@ -68,7 +68,8 @@ import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
@@ -101,13 +102,36 @@ class NotificationOptions:
|
||||
self.tools_changed = tools_changed
|
||||
|
||||
|
||||
class Server:
|
||||
LifespanResultT = TypeVar("LifespanResultT")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(server: "Server") -> AsyncIterator[object]:
|
||||
"""Default lifespan context manager that does nothing.
|
||||
|
||||
Args:
|
||||
server: The server instance this lifespan is managing
|
||||
|
||||
Returns:
|
||||
An empty context object
|
||||
"""
|
||||
yield {}
|
||||
|
||||
|
||||
class Server(Generic[LifespanResultT]):
|
||||
def __init__(
|
||||
self, name: str, version: str | None = None, instructions: str | None = None
|
||||
self,
|
||||
name: str,
|
||||
version: str | None = None,
|
||||
instructions: str | None = None,
|
||||
lifespan: Callable[
|
||||
["Server"], AbstractAsyncContextManager[LifespanResultT]
|
||||
] = lifespan,
|
||||
):
|
||||
self.name = name
|
||||
self.version = version
|
||||
self.instructions = instructions
|
||||
self.lifespan = lifespan
|
||||
self.request_handlers: dict[
|
||||
type, Callable[..., Awaitable[types.ServerResult]]
|
||||
] = {
|
||||
@@ -446,6 +470,7 @@ class Server:
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with self.lifespan(self) as lifespan_context:
|
||||
async with ServerSession(
|
||||
read_stream, write_stream, initialization_options
|
||||
) as session:
|
||||
@@ -460,14 +485,20 @@ class Server:
|
||||
):
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
message, req, session, raise_exceptions
|
||||
message,
|
||||
req,
|
||||
session,
|
||||
lifespan_context,
|
||||
raise_exceptions,
|
||||
)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
"Warning: %s: %s",
|
||||
warning.category.__name__,
|
||||
warning.message,
|
||||
)
|
||||
|
||||
async def _handle_request(
|
||||
@@ -475,6 +506,7 @@ class Server:
|
||||
message: RequestResponder,
|
||||
req: Any,
|
||||
session: ServerSession,
|
||||
lifespan_context: object,
|
||||
raise_exceptions: bool,
|
||||
):
|
||||
logger.info(f"Processing request of type {type(req).__name__}")
|
||||
@@ -491,6 +523,7 @@ class Server:
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
lifespan_context,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.types import RequestId, RequestParams
|
||||
@@ -12,3 +12,4 @@ class RequestContext(Generic[SessionT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: Any
|
||||
|
||||
Reference in New Issue
Block a user