mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Merge pull request #203 from modelcontextprotocol/davidsp/clean-lifespan
feat: add lifespan context manager support
This commit is contained in:
61
README.md
61
README.md
@@ -128,6 +128,9 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui
|
||||
The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing:
|
||||
|
||||
```python
|
||||
# Add lifespan support for startup/shutdown with strong typing
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Create a named server
|
||||
@@ -135,6 +138,31 @@ mcp = FastMCP("My App")
|
||||
|
||||
# Specify dependencies for deployment and development
|
||||
mcp = FastMCP("My App", dependencies=["pandas", "numpy"])
|
||||
|
||||
@dataclass
|
||||
class AppContext:
|
||||
db: Database # Replace with your actual DB type
|
||||
|
||||
@asynccontextmanager
|
||||
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
|
||||
"""Manage application lifecycle with type-safe context"""
|
||||
try:
|
||||
# Initialize on startup
|
||||
await db.connect()
|
||||
yield AppContext(db=db)
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
await db.disconnect()
|
||||
|
||||
# Pass lifespan to server
|
||||
mcp = FastMCP("My App", lifespan=app_lifespan)
|
||||
|
||||
# Access type-safe lifespan context in tools
|
||||
@mcp.tool()
|
||||
def query_db(ctx: Context) -> str:
|
||||
"""Tool that uses initialized resources"""
|
||||
db = ctx.request_context.lifespan_context["db"]
|
||||
return db.query()
|
||||
```
|
||||
|
||||
### Resources
|
||||
@@ -334,7 +362,38 @@ def query_data(sql: str) -> str:
|
||||
|
||||
### Low-Level Server
|
||||
|
||||
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server:
|
||||
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server, including lifecycle management through the lifespan API:
|
||||
|
||||
```python
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
|
||||
@asynccontextmanager
|
||||
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
|
||||
"""Manage server startup and shutdown lifecycle."""
|
||||
try:
|
||||
# Initialize resources on startup
|
||||
await db.connect()
|
||||
yield {"db": db}
|
||||
finally:
|
||||
# Clean up on shutdown
|
||||
await db.disconnect()
|
||||
|
||||
# Pass lifespan to server
|
||||
server = Server("example-server", lifespan=server_lifespan)
|
||||
|
||||
# Access lifespan context in handlers
|
||||
@server.call_tool()
|
||||
async def query_db(name: str, arguments: dict) -> list:
|
||||
ctx = server.request_context
|
||||
db = ctx.lifespan_context["db"]
|
||||
return await db.query(arguments["query"])
|
||||
```
|
||||
|
||||
The lifespan API provides:
|
||||
- A way to initialize resources when the server starts and clean them up when it stops
|
||||
- Access to initialized resources through the request context in handlers
|
||||
- Type-safe context passing between lifespan and request handlers
|
||||
|
||||
```python
|
||||
from mcp.server.lowlevel import Server, NotificationOptions
|
||||
|
||||
@@ -3,8 +3,13 @@
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import (
|
||||
AbstractAsyncContextManager,
|
||||
asynccontextmanager,
|
||||
)
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Literal, Sequence
|
||||
from typing import Any, Callable, Generic, Literal, Sequence
|
||||
|
||||
import anyio
|
||||
import pydantic_core
|
||||
@@ -19,8 +24,16 @@ from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceMan
|
||||
from mcp.server.fastmcp.tools import ToolManager
|
||||
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
|
||||
from mcp.server.fastmcp.utilities.types import Image
|
||||
from mcp.server.lowlevel import Server as MCPServer
|
||||
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||
from mcp.server.lowlevel.server import (
|
||||
LifespanResultT,
|
||||
)
|
||||
from mcp.server.lowlevel.server import (
|
||||
Server as MCPServer,
|
||||
)
|
||||
from mcp.server.lowlevel.server import (
|
||||
lifespan as default_lifespan,
|
||||
)
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.shared.context import RequestContext
|
||||
@@ -50,7 +63,7 @@ from mcp.types import (
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
"""FastMCP server settings.
|
||||
|
||||
All settings can be configured via environment variables with the prefix FASTMCP_.
|
||||
@@ -85,13 +98,36 @@ class Settings(BaseSettings):
|
||||
description="List of dependencies to install in the server environment",
|
||||
)
|
||||
|
||||
lifespan: (
|
||||
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||
) = Field(None, description="Lifespan context manager")
|
||||
|
||||
|
||||
def lifespan_wrapper(
|
||||
app: "FastMCP",
|
||||
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
|
||||
) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]:
|
||||
@asynccontextmanager
|
||||
async def wrap(s: MCPServer) -> AsyncIterator[object]:
|
||||
async with lifespan(app) as context:
|
||||
yield context
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
class FastMCP:
|
||||
def __init__(
|
||||
self, name: str | None = None, instructions: str | None = None, **settings: Any
|
||||
):
|
||||
self.settings = Settings(**settings)
|
||||
self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions)
|
||||
|
||||
self._mcp_server = MCPServer(
|
||||
name=name or "FastMCP",
|
||||
instructions=instructions,
|
||||
lifespan=lifespan_wrapper(self, self.settings.lifespan)
|
||||
if self.settings.lifespan
|
||||
else default_lifespan,
|
||||
)
|
||||
self._tool_manager = ToolManager(
|
||||
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
|
||||
)
|
||||
|
||||
@@ -68,7 +68,8 @@ import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
@@ -84,7 +85,10 @@ from mcp.shared.session import RequestResponder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
|
||||
LifespanResultT = TypeVar("LifespanResultT")
|
||||
|
||||
# This will be properly typed in each Server instance's context
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
|
||||
contextvars.ContextVar("request_ctx")
|
||||
)
|
||||
|
||||
@@ -101,13 +105,33 @@ class NotificationOptions:
|
||||
self.tools_changed = tools_changed
|
||||
|
||||
|
||||
class Server:
|
||||
@asynccontextmanager
|
||||
async def lifespan(server: "Server") -> AsyncIterator[object]:
|
||||
"""Default lifespan context manager that does nothing.
|
||||
|
||||
Args:
|
||||
server: The server instance this lifespan is managing
|
||||
|
||||
Returns:
|
||||
An empty context object
|
||||
"""
|
||||
yield {}
|
||||
|
||||
|
||||
class Server(Generic[LifespanResultT]):
|
||||
def __init__(
|
||||
self, name: str, version: str | None = None, instructions: str | None = None
|
||||
self,
|
||||
name: str,
|
||||
version: str | None = None,
|
||||
instructions: str | None = None,
|
||||
lifespan: Callable[
|
||||
["Server"], AbstractAsyncContextManager[LifespanResultT]
|
||||
] = lifespan,
|
||||
):
|
||||
self.name = name
|
||||
self.version = version
|
||||
self.instructions = instructions
|
||||
self.lifespan = lifespan
|
||||
self.request_handlers: dict[
|
||||
type, Callable[..., Awaitable[types.ServerResult]]
|
||||
] = {
|
||||
@@ -188,7 +212,7 @@ class Server:
|
||||
)
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext[ServerSession]:
|
||||
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
@@ -446,9 +470,14 @@ class Server:
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(
|
||||
read_stream, write_stream, initialization_options
|
||||
) as session:
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||
session = await stack.enter_async_context(
|
||||
ServerSession(read_stream, write_stream, initialization_options)
|
||||
)
|
||||
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
@@ -460,14 +489,20 @@ class Server:
|
||||
):
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
message, req, session, raise_exceptions
|
||||
message,
|
||||
req,
|
||||
session,
|
||||
lifespan_context,
|
||||
raise_exceptions,
|
||||
)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
"Warning: %s: %s",
|
||||
warning.category.__name__,
|
||||
warning.message,
|
||||
)
|
||||
|
||||
async def _handle_request(
|
||||
@@ -475,6 +510,7 @@ class Server:
|
||||
message: RequestResponder,
|
||||
req: Any,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool,
|
||||
):
|
||||
logger.info(f"Processing request of type {type(req).__name__}")
|
||||
@@ -491,6 +527,7 @@ class Server:
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
lifespan_context,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
|
||||
@@ -5,10 +5,12 @@ from mcp.shared.session import BaseSession
|
||||
from mcp.types import RequestId, RequestParams
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession)
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT]):
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
|
||||
@@ -20,7 +20,10 @@ async def test_progress_token_zero_first_call():
|
||||
mock_meta.progressToken = 0 # This is the key test case - token is 0
|
||||
|
||||
request_context = RequestContext(
|
||||
request_id="test-request", session=mock_session, meta=mock_meta
|
||||
request_id="test-request",
|
||||
session=mock_session,
|
||||
meta=mock_meta,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
# Create context with our mocks
|
||||
|
||||
@@ -236,7 +236,7 @@ async def test_lambda_function():
|
||||
|
||||
def test_complex_function_json_schema():
|
||||
"""Test JSON schema generation for complex function arguments.
|
||||
|
||||
|
||||
Note: Different versions of pydantic output slightly different
|
||||
JSON Schema formats for model fields with defaults. The format changed in 2.9.0:
|
||||
|
||||
@@ -245,16 +245,16 @@ def test_complex_function_json_schema():
|
||||
"allOf": [{"$ref": "#/$defs/Model"}],
|
||||
"default": {}
|
||||
}
|
||||
|
||||
|
||||
2. Since 2.9.0:
|
||||
{
|
||||
"$ref": "#/$defs/Model",
|
||||
"default": {}
|
||||
}
|
||||
|
||||
|
||||
Both formats are valid and functionally equivalent. This test accepts either format
|
||||
to ensure compatibility across our supported pydantic versions.
|
||||
|
||||
|
||||
This change in format does not affect runtime behavior since:
|
||||
1. Both schemas validate the same way
|
||||
2. The actual model classes and validation logic are unchanged
|
||||
@@ -262,17 +262,17 @@ def test_complex_function_json_schema():
|
||||
"""
|
||||
meta = func_metadata(complex_arguments_fn)
|
||||
actual_schema = meta.arg_model.model_json_schema()
|
||||
|
||||
|
||||
# Create a copy of the actual schema to normalize
|
||||
normalized_schema = actual_schema.copy()
|
||||
|
||||
|
||||
# Normalize the my_model_a_with_default field to handle both pydantic formats
|
||||
if 'allOf' in actual_schema['properties']['my_model_a_with_default']:
|
||||
normalized_schema['properties']['my_model_a_with_default'] = {
|
||||
'$ref': '#/$defs/SomeInputModelA',
|
||||
'default': {}
|
||||
if "allOf" in actual_schema["properties"]["my_model_a_with_default"]:
|
||||
normalized_schema["properties"]["my_model_a_with_default"] = {
|
||||
"$ref": "#/$defs/SomeInputModelA",
|
||||
"default": {},
|
||||
}
|
||||
|
||||
|
||||
assert normalized_schema == {
|
||||
"$defs": {
|
||||
"InnerModel": {
|
||||
|
||||
207
tests/server/test_lifespan.py
Normal file
207
tests/server/test_lifespan.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Tests for lifespan functionality in both low-level and FastMCP servers."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
from mcp.server.lowlevel.server import NotificationOptions, Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.types import (
|
||||
ClientCapabilities,
|
||||
Implementation,
|
||||
InitializeRequestParams,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_lowlevel_server_lifespan():
|
||||
"""Test that lifespan works in low-level server."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def test_lifespan(server: Server) -> AsyncIterator[dict]:
|
||||
"""Test lifespan context that tracks startup/shutdown."""
|
||||
context = {"started": False, "shutdown": False}
|
||||
try:
|
||||
context["started"] = True
|
||||
yield context
|
||||
finally:
|
||||
context["shutdown"] = True
|
||||
|
||||
server = Server("test", lifespan=test_lifespan)
|
||||
|
||||
# Create memory streams for testing
|
||||
send_stream1, receive_stream1 = anyio.create_memory_object_stream(100)
|
||||
send_stream2, receive_stream2 = anyio.create_memory_object_stream(100)
|
||||
|
||||
# Create a tool that accesses lifespan context
|
||||
@server.call_tool()
|
||||
async def check_lifespan(name: str, arguments: dict) -> list:
|
||||
ctx = server.request_context
|
||||
assert isinstance(ctx.lifespan_context, dict)
|
||||
assert ctx.lifespan_context["started"]
|
||||
assert not ctx.lifespan_context["shutdown"]
|
||||
return [{"type": "text", "text": "true"}]
|
||||
|
||||
# Run server in background task
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
async def run_server():
|
||||
await server.run(
|
||||
receive_stream1,
|
||||
send_stream2,
|
||||
InitializationOptions(
|
||||
server_name="test",
|
||||
server_version="0.1.0",
|
||||
capabilities=server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
raise_exceptions=True,
|
||||
)
|
||||
|
||||
tg.start_soon(run_server)
|
||||
|
||||
# Initialize the server
|
||||
params = InitializeRequestParams(
|
||||
protocolVersion="2024-11-05",
|
||||
capabilities=ClientCapabilities(),
|
||||
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
||||
)
|
||||
await send_stream1.send(
|
||||
JSONRPCMessage(
|
||||
root=JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=1,
|
||||
method="initialize",
|
||||
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
||||
)
|
||||
)
|
||||
)
|
||||
response = await receive_stream2.receive()
|
||||
|
||||
# Send initialized notification
|
||||
await send_stream1.send(
|
||||
JSONRPCMessage(
|
||||
root=JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
method="notifications/initialized",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Call the tool to verify lifespan context
|
||||
await send_stream1.send(
|
||||
JSONRPCMessage(
|
||||
root=JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=2,
|
||||
method="tools/call",
|
||||
params={"name": "check_lifespan", "arguments": {}},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Get response and verify
|
||||
response = await receive_stream2.receive()
|
||||
assert response.root.result["content"][0]["text"] == "true"
|
||||
|
||||
# Cancel server task
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fastmcp_server_lifespan():
|
||||
"""Test that lifespan works in FastMCP server."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def test_lifespan(server: FastMCP) -> AsyncIterator[dict]:
|
||||
"""Test lifespan context that tracks startup/shutdown."""
|
||||
context = {"started": False, "shutdown": False}
|
||||
try:
|
||||
context["started"] = True
|
||||
yield context
|
||||
finally:
|
||||
context["shutdown"] = True
|
||||
|
||||
server = FastMCP("test", lifespan=test_lifespan)
|
||||
|
||||
# Create memory streams for testing
|
||||
send_stream1, receive_stream1 = anyio.create_memory_object_stream(100)
|
||||
send_stream2, receive_stream2 = anyio.create_memory_object_stream(100)
|
||||
|
||||
# Add a tool that checks lifespan context
|
||||
@server.tool()
|
||||
def check_lifespan(ctx: Context) -> bool:
|
||||
"""Tool that checks lifespan context."""
|
||||
assert isinstance(ctx.request_context.lifespan_context, dict)
|
||||
assert ctx.request_context.lifespan_context["started"]
|
||||
assert not ctx.request_context.lifespan_context["shutdown"]
|
||||
return True
|
||||
|
||||
# Run server in background task
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
async def run_server():
|
||||
await server._mcp_server.run(
|
||||
receive_stream1,
|
||||
send_stream2,
|
||||
server._mcp_server.create_initialization_options(),
|
||||
raise_exceptions=True,
|
||||
)
|
||||
|
||||
tg.start_soon(run_server)
|
||||
|
||||
# Initialize the server
|
||||
params = InitializeRequestParams(
|
||||
protocolVersion="2024-11-05",
|
||||
capabilities=ClientCapabilities(),
|
||||
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
||||
)
|
||||
await send_stream1.send(
|
||||
JSONRPCMessage(
|
||||
root=JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=1,
|
||||
method="initialize",
|
||||
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
||||
)
|
||||
)
|
||||
)
|
||||
response = await receive_stream2.receive()
|
||||
|
||||
# Send initialized notification
|
||||
await send_stream1.send(
|
||||
JSONRPCMessage(
|
||||
root=JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
method="notifications/initialized",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Call the tool to verify lifespan context
|
||||
await send_stream1.send(
|
||||
JSONRPCMessage(
|
||||
root=JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=2,
|
||||
method="tools/call",
|
||||
params={"name": "check_lifespan", "arguments": {}},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Get response and verify
|
||||
response = await receive_stream2.receive()
|
||||
assert response.root.result["content"][0]["text"] == "true"
|
||||
|
||||
# Cancel server task
|
||||
tg.cancel_scope.cancel()
|
||||
Reference in New Issue
Block a user