mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
feat: add lifespan support to low-level MCP server
Adds a context manager based lifespan API in mcp.server.lowlevel.server to manage server lifecycles in a type-safe way. This enables servers to: - Initialize resources on startup and clean them up on shutdown - Pass context data from startup to request handlers - Support async startup/shutdown operations
This commit is contained in:
@@ -68,7 +68,8 @@ import contextvars
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any, Sequence
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
|
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
|
||||||
|
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
@@ -101,13 +102,36 @@ class NotificationOptions:
|
|||||||
self.tools_changed = tools_changed
|
self.tools_changed = tools_changed
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
LifespanResultT = TypeVar("LifespanResultT")
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(server: "Server") -> AsyncIterator[object]:
|
||||||
|
"""Default lifespan context manager that does nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: The server instance this lifespan is managing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An empty context object
|
||||||
|
"""
|
||||||
|
yield {}
|
||||||
|
|
||||||
|
|
||||||
|
class Server(Generic[LifespanResultT]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str, version: str | None = None, instructions: str | None = None
|
self,
|
||||||
|
name: str,
|
||||||
|
version: str | None = None,
|
||||||
|
instructions: str | None = None,
|
||||||
|
lifespan: Callable[
|
||||||
|
["Server"], AbstractAsyncContextManager[LifespanResultT]
|
||||||
|
] = lifespan,
|
||||||
):
|
):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.version = version
|
self.version = version
|
||||||
self.instructions = instructions
|
self.instructions = instructions
|
||||||
|
self.lifespan = lifespan
|
||||||
self.request_handlers: dict[
|
self.request_handlers: dict[
|
||||||
type, Callable[..., Awaitable[types.ServerResult]]
|
type, Callable[..., Awaitable[types.ServerResult]]
|
||||||
] = {
|
] = {
|
||||||
@@ -446,6 +470,7 @@ class Server:
|
|||||||
raise_exceptions: bool = False,
|
raise_exceptions: bool = False,
|
||||||
):
|
):
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
async with self.lifespan(self) as lifespan_context:
|
||||||
async with ServerSession(
|
async with ServerSession(
|
||||||
read_stream, write_stream, initialization_options
|
read_stream, write_stream, initialization_options
|
||||||
) as session:
|
) as session:
|
||||||
@@ -460,14 +485,20 @@ class Server:
|
|||||||
):
|
):
|
||||||
with responder:
|
with responder:
|
||||||
await self._handle_request(
|
await self._handle_request(
|
||||||
message, req, session, raise_exceptions
|
message,
|
||||||
|
req,
|
||||||
|
session,
|
||||||
|
lifespan_context,
|
||||||
|
raise_exceptions,
|
||||||
)
|
)
|
||||||
case types.ClientNotification(root=notify):
|
case types.ClientNotification(root=notify):
|
||||||
await self._handle_notification(notify)
|
await self._handle_notification(notify)
|
||||||
|
|
||||||
for warning in w:
|
for warning in w:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
"Warning: %s: %s",
|
||||||
|
warning.category.__name__,
|
||||||
|
warning.message,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_request(
|
async def _handle_request(
|
||||||
@@ -475,6 +506,7 @@ class Server:
|
|||||||
message: RequestResponder,
|
message: RequestResponder,
|
||||||
req: Any,
|
req: Any,
|
||||||
session: ServerSession,
|
session: ServerSession,
|
||||||
|
lifespan_context: object,
|
||||||
raise_exceptions: bool,
|
raise_exceptions: bool,
|
||||||
):
|
):
|
||||||
logger.info(f"Processing request of type {type(req).__name__}")
|
logger.info(f"Processing request of type {type(req).__name__}")
|
||||||
@@ -491,6 +523,7 @@ class Server:
|
|||||||
message.request_id,
|
message.request_id,
|
||||||
message.request_meta,
|
message.request_meta,
|
||||||
session,
|
session,
|
||||||
|
lifespan_context,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await handler(req)
|
response = await handler(req)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Generic, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
from mcp.shared.session import BaseSession
|
from mcp.shared.session import BaseSession
|
||||||
from mcp.types import RequestId, RequestParams
|
from mcp.types import RequestId, RequestParams
|
||||||
@@ -12,3 +12,4 @@ class RequestContext(Generic[SessionT]):
|
|||||||
request_id: RequestId
|
request_id: RequestId
|
||||||
meta: RequestParams.Meta | None
|
meta: RequestParams.Meta | None
|
||||||
session: SessionT
|
session: SessionT
|
||||||
|
lifespan_context: Any
|
||||||
|
|||||||
Reference in New Issue
Block a user