Merge pull request #203 from modelcontextprotocol/davidsp/clean-lifespan

feat: add lifespan context manager support
This commit is contained in:
David Soria Parra
2025-02-13 14:26:28 +00:00
committed by GitHub
7 changed files with 372 additions and 28 deletions

View File

@@ -128,6 +128,9 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui
The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing:
```python
# Add lifespan support for startup/shutdown with strong typing
from dataclasses import dataclass
from typing import AsyncIterator
from mcp.server.fastmcp import FastMCP
# Create a named server
@@ -135,6 +138,31 @@ mcp = FastMCP("My App")
# Specify dependencies for deployment and development
mcp = FastMCP("My App", dependencies=["pandas", "numpy"])
@dataclass
class AppContext:
db: Database # Replace with your actual DB type
@asynccontextmanager
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
"""Manage application lifecycle with type-safe context"""
try:
# Initialize on startup
await db.connect()
yield AppContext(db=db)
finally:
# Cleanup on shutdown
await db.disconnect()
# Pass lifespan to server
mcp = FastMCP("My App", lifespan=app_lifespan)
# Access type-safe lifespan context in tools
@mcp.tool()
def query_db(ctx: Context) -> str:
"""Tool that uses initialized resources"""
db = ctx.request_context.lifespan_context["db"]
return db.query()
```
### Resources
@@ -334,7 +362,38 @@ def query_data(sql: str) -> str:
### Low-Level Server
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server:
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server, including lifecycle management through the lifespan API:
```python
from contextlib import asynccontextmanager
from typing import AsyncIterator
@asynccontextmanager
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
"""Manage server startup and shutdown lifecycle."""
try:
# Initialize resources on startup
await db.connect()
yield {"db": db}
finally:
# Clean up on shutdown
await db.disconnect()
# Pass lifespan to server
server = Server("example-server", lifespan=server_lifespan)
# Access lifespan context in handlers
@server.call_tool()
async def query_db(name: str, arguments: dict) -> list:
ctx = server.request_context
db = ctx.lifespan_context["db"]
return await db.query(arguments["query"])
```
The lifespan API provides:
- A way to initialize resources when the server starts and clean them up when it stops
- Access to initialized resources through the request context in handlers
- Type-safe context passing between lifespan and request handlers
```python
from mcp.server.lowlevel import Server, NotificationOptions

View File

@@ -3,8 +3,13 @@
import inspect
import json
import re
from collections.abc import AsyncIterator
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from itertools import chain
from typing import Any, Callable, Literal, Sequence
from typing import Any, Callable, Generic, Literal, Sequence
import anyio
import pydantic_core
@@ -19,8 +24,16 @@ from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceMan
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.fastmcp.utilities.types import Image
from mcp.server.lowlevel import Server as MCPServer
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel.server import (
LifespanResultT,
)
from mcp.server.lowlevel.server import (
Server as MCPServer,
)
from mcp.server.lowlevel.server import (
lifespan as default_lifespan,
)
from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
from mcp.shared.context import RequestContext
@@ -50,7 +63,7 @@ from mcp.types import (
logger = get_logger(__name__)
class Settings(BaseSettings):
class Settings(BaseSettings, Generic[LifespanResultT]):
"""FastMCP server settings.
All settings can be configured via environment variables with the prefix FASTMCP_.
@@ -85,13 +98,36 @@ class Settings(BaseSettings):
description="List of dependencies to install in the server environment",
)
lifespan: (
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager")
def lifespan_wrapper(
app: "FastMCP",
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]:
@asynccontextmanager
async def wrap(s: MCPServer) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context
return wrap
class FastMCP:
def __init__(
self, name: str | None = None, instructions: str | None = None, **settings: Any
):
self.settings = Settings(**settings)
self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions)
self._mcp_server = MCPServer(
name=name or "FastMCP",
instructions=instructions,
lifespan=lifespan_wrapper(self, self.settings.lifespan)
if self.settings.lifespan
else default_lifespan,
)
self._tool_manager = ToolManager(
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
)

View File

@@ -68,7 +68,8 @@ import contextvars
import logging
import warnings
from collections.abc import Awaitable, Callable
from typing import Any, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
@@ -84,7 +85,10 @@ from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__)
request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
LifespanResultT = TypeVar("LifespanResultT")
# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
contextvars.ContextVar("request_ctx")
)
@@ -101,13 +105,33 @@ class NotificationOptions:
self.tools_changed = tools_changed
class Server:
@asynccontextmanager
async def lifespan(server: "Server") -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.
Args:
server: The server instance this lifespan is managing
Returns:
An empty context object
"""
yield {}
class Server(Generic[LifespanResultT]):
def __init__(
self, name: str, version: str | None = None, instructions: str | None = None
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
["Server"], AbstractAsyncContextManager[LifespanResultT]
] = lifespan,
):
self.name = name
self.version = version
self.instructions = instructions
self.lifespan = lifespan
self.request_handlers: dict[
type, Callable[..., Awaitable[types.ServerResult]]
] = {
@@ -188,7 +212,7 @@ class Server:
)
@property
def request_context(self) -> RequestContext[ServerSession]:
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()
@@ -446,9 +470,14 @@ class Server:
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
from contextlib import AsyncExitStack
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
session = await stack.enter_async_context(
ServerSession(read_stream, write_stream, initialization_options)
)
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
@@ -460,14 +489,20 @@ class Server:
):
with responder:
await self._handle_request(
message, req, session, raise_exceptions
message,
req,
session,
lifespan_context,
raise_exceptions,
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
"Warning: %s: %s",
warning.category.__name__,
warning.message,
)
async def _handle_request(
@@ -475,6 +510,7 @@ class Server:
message: RequestResponder,
req: Any,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool,
):
logger.info(f"Processing request of type {type(req).__name__}")
@@ -491,6 +527,7 @@ class Server:
message.request_id,
message.request_meta,
session,
lifespan_context,
)
)
response = await handler(req)

View File

@@ -5,10 +5,12 @@ from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParams
SessionT = TypeVar("SessionT", bound=BaseSession)
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext(Generic[SessionT]):
class RequestContext(Generic[SessionT, LifespanContextT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT

View File

@@ -20,7 +20,10 @@ async def test_progress_token_zero_first_call():
mock_meta.progressToken = 0 # This is the key test case - token is 0
request_context = RequestContext(
request_id="test-request", session=mock_session, meta=mock_meta
request_id="test-request",
session=mock_session,
meta=mock_meta,
lifespan_context=None,
)
# Create context with our mocks

View File

@@ -236,7 +236,7 @@ async def test_lambda_function():
def test_complex_function_json_schema():
"""Test JSON schema generation for complex function arguments.
Note: Different versions of pydantic output slightly different
JSON Schema formats for model fields with defaults. The format changed in 2.9.0:
@@ -245,16 +245,16 @@ def test_complex_function_json_schema():
"allOf": [{"$ref": "#/$defs/Model"}],
"default": {}
}
2. Since 2.9.0:
{
"$ref": "#/$defs/Model",
"default": {}
}
Both formats are valid and functionally equivalent. This test accepts either format
to ensure compatibility across our supported pydantic versions.
This change in format does not affect runtime behavior since:
1. Both schemas validate the same way
2. The actual model classes and validation logic are unchanged
@@ -262,17 +262,17 @@ def test_complex_function_json_schema():
"""
meta = func_metadata(complex_arguments_fn)
actual_schema = meta.arg_model.model_json_schema()
# Create a copy of the actual schema to normalize
normalized_schema = actual_schema.copy()
# Normalize the my_model_a_with_default field to handle both pydantic formats
if 'allOf' in actual_schema['properties']['my_model_a_with_default']:
normalized_schema['properties']['my_model_a_with_default'] = {
'$ref': '#/$defs/SomeInputModelA',
'default': {}
if "allOf" in actual_schema["properties"]["my_model_a_with_default"]:
normalized_schema["properties"]["my_model_a_with_default"] = {
"$ref": "#/$defs/SomeInputModelA",
"default": {},
}
assert normalized_schema == {
"$defs": {
"InnerModel": {

View File

@@ -0,0 +1,207 @@
"""Tests for lifespan functionality in both low-level and FastMCP servers."""
from contextlib import asynccontextmanager
from typing import AsyncIterator
import anyio
import pytest
from pydantic import TypeAdapter
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.lowlevel.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from mcp.types import (
ClientCapabilities,
Implementation,
InitializeRequestParams,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
)
@pytest.mark.anyio
async def test_lowlevel_server_lifespan():
"""Test that lifespan works in low-level server."""
@asynccontextmanager
async def test_lifespan(server: Server) -> AsyncIterator[dict]:
"""Test lifespan context that tracks startup/shutdown."""
context = {"started": False, "shutdown": False}
try:
context["started"] = True
yield context
finally:
context["shutdown"] = True
server = Server("test", lifespan=test_lifespan)
# Create memory streams for testing
send_stream1, receive_stream1 = anyio.create_memory_object_stream(100)
send_stream2, receive_stream2 = anyio.create_memory_object_stream(100)
# Create a tool that accesses lifespan context
@server.call_tool()
async def check_lifespan(name: str, arguments: dict) -> list:
ctx = server.request_context
assert isinstance(ctx.lifespan_context, dict)
assert ctx.lifespan_context["started"]
assert not ctx.lifespan_context["shutdown"]
return [{"type": "text", "text": "true"}]
# Run server in background task
async with anyio.create_task_group() as tg:
async def run_server():
await server.run(
receive_stream1,
send_stream2,
InitializationOptions(
server_name="test",
server_version="0.1.0",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
raise_exceptions=True,
)
tg.start_soon(run_server)
# Initialize the server
params = InitializeRequestParams(
protocolVersion="2024-11-05",
capabilities=ClientCapabilities(),
clientInfo=Implementation(name="test-client", version="0.1.0"),
)
await send_stream1.send(
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
)
)
response = await receive_stream2.receive()
# Send initialized notification
await send_stream1.send(
JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
)
)
# Call the tool to verify lifespan context
await send_stream1.send(
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
)
)
# Get response and verify
response = await receive_stream2.receive()
assert response.root.result["content"][0]["text"] == "true"
# Cancel server task
tg.cancel_scope.cancel()
@pytest.mark.anyio
async def test_fastmcp_server_lifespan():
"""Test that lifespan works in FastMCP server."""
@asynccontextmanager
async def test_lifespan(server: FastMCP) -> AsyncIterator[dict]:
"""Test lifespan context that tracks startup/shutdown."""
context = {"started": False, "shutdown": False}
try:
context["started"] = True
yield context
finally:
context["shutdown"] = True
server = FastMCP("test", lifespan=test_lifespan)
# Create memory streams for testing
send_stream1, receive_stream1 = anyio.create_memory_object_stream(100)
send_stream2, receive_stream2 = anyio.create_memory_object_stream(100)
# Add a tool that checks lifespan context
@server.tool()
def check_lifespan(ctx: Context) -> bool:
"""Tool that checks lifespan context."""
assert isinstance(ctx.request_context.lifespan_context, dict)
assert ctx.request_context.lifespan_context["started"]
assert not ctx.request_context.lifespan_context["shutdown"]
return True
# Run server in background task
async with anyio.create_task_group() as tg:
async def run_server():
await server._mcp_server.run(
receive_stream1,
send_stream2,
server._mcp_server.create_initialization_options(),
raise_exceptions=True,
)
tg.start_soon(run_server)
# Initialize the server
params = InitializeRequestParams(
protocolVersion="2024-11-05",
capabilities=ClientCapabilities(),
clientInfo=Implementation(name="test-client", version="0.1.0"),
)
await send_stream1.send(
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=TypeAdapter(InitializeRequestParams).dump_python(params),
)
)
)
response = await receive_stream2.receive()
# Send initialized notification
await send_stream1.send(
JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
)
)
)
# Call the tool to verify lifespan context
await send_stream1.send(
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "check_lifespan", "arguments": {}},
)
)
)
# Get response and verify
response = await receive_stream2.receive()
assert response.root.result["content"][0]["text"] == "true"
# Cancel server task
tg.cancel_scope.cancel()