mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Merge pull request #203 from modelcontextprotocol/davidsp/clean-lifespan
feat: add lifespan context manager support
This commit is contained in:
61
README.md
61
README.md
@@ -128,6 +128,9 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui
|
|||||||
The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing:
|
The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
# Add lifespan support for startup/shutdown with strong typing
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import AsyncIterator
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
# Create a named server
|
# Create a named server
|
||||||
@@ -135,6 +138,31 @@ mcp = FastMCP("My App")
|
|||||||
|
|
||||||
# Specify dependencies for deployment and development
|
# Specify dependencies for deployment and development
|
||||||
mcp = FastMCP("My App", dependencies=["pandas", "numpy"])
|
mcp = FastMCP("My App", dependencies=["pandas", "numpy"])
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AppContext:
|
||||||
|
db: Database # Replace with your actual DB type
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
|
||||||
|
"""Manage application lifecycle with type-safe context"""
|
||||||
|
try:
|
||||||
|
# Initialize on startup
|
||||||
|
await db.connect()
|
||||||
|
yield AppContext(db=db)
|
||||||
|
finally:
|
||||||
|
# Cleanup on shutdown
|
||||||
|
await db.disconnect()
|
||||||
|
|
||||||
|
# Pass lifespan to server
|
||||||
|
mcp = FastMCP("My App", lifespan=app_lifespan)
|
||||||
|
|
||||||
|
# Access type-safe lifespan context in tools
|
||||||
|
@mcp.tool()
|
||||||
|
def query_db(ctx: Context) -> str:
|
||||||
|
"""Tool that uses initialized resources"""
|
||||||
|
db = ctx.request_context.lifespan_context["db"]
|
||||||
|
return db.query()
|
||||||
```
|
```
|
||||||
|
|
||||||
### Resources
|
### Resources
|
||||||
@@ -334,7 +362,38 @@ def query_data(sql: str) -> str:
|
|||||||
|
|
||||||
### Low-Level Server
|
### Low-Level Server
|
||||||
|
|
||||||
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server:
|
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server, including lifecycle management through the lifespan API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
|
||||||
|
"""Manage server startup and shutdown lifecycle."""
|
||||||
|
try:
|
||||||
|
# Initialize resources on startup
|
||||||
|
await db.connect()
|
||||||
|
yield {"db": db}
|
||||||
|
finally:
|
||||||
|
# Clean up on shutdown
|
||||||
|
await db.disconnect()
|
||||||
|
|
||||||
|
# Pass lifespan to server
|
||||||
|
server = Server("example-server", lifespan=server_lifespan)
|
||||||
|
|
||||||
|
# Access lifespan context in handlers
|
||||||
|
@server.call_tool()
|
||||||
|
async def query_db(name: str, arguments: dict) -> list:
|
||||||
|
ctx = server.request_context
|
||||||
|
db = ctx.lifespan_context["db"]
|
||||||
|
return await db.query(arguments["query"])
|
||||||
|
```
|
||||||
|
|
||||||
|
The lifespan API provides:
|
||||||
|
- A way to initialize resources when the server starts and clean them up when it stops
|
||||||
|
- Access to initialized resources through the request context in handlers
|
||||||
|
- Type-safe context passing between lifespan and request handlers
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mcp.server.lowlevel import Server, NotificationOptions
|
from mcp.server.lowlevel import Server, NotificationOptions
|
||||||
|
|||||||
@@ -3,8 +3,13 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import (
|
||||||
|
AbstractAsyncContextManager,
|
||||||
|
asynccontextmanager,
|
||||||
|
)
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any, Callable, Literal, Sequence
|
from typing import Any, Callable, Generic, Literal, Sequence
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pydantic_core
|
import pydantic_core
|
||||||
@@ -19,8 +24,16 @@ from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceMan
|
|||||||
from mcp.server.fastmcp.tools import ToolManager
|
from mcp.server.fastmcp.tools import ToolManager
|
||||||
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
|
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
|
||||||
from mcp.server.fastmcp.utilities.types import Image
|
from mcp.server.fastmcp.utilities.types import Image
|
||||||
from mcp.server.lowlevel import Server as MCPServer
|
|
||||||
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||||
|
from mcp.server.lowlevel.server import (
|
||||||
|
LifespanResultT,
|
||||||
|
)
|
||||||
|
from mcp.server.lowlevel.server import (
|
||||||
|
Server as MCPServer,
|
||||||
|
)
|
||||||
|
from mcp.server.lowlevel.server import (
|
||||||
|
lifespan as default_lifespan,
|
||||||
|
)
|
||||||
from mcp.server.sse import SseServerTransport
|
from mcp.server.sse import SseServerTransport
|
||||||
from mcp.server.stdio import stdio_server
|
from mcp.server.stdio import stdio_server
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
@@ -50,7 +63,7 @@ from mcp.types import (
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||||
"""FastMCP server settings.
|
"""FastMCP server settings.
|
||||||
|
|
||||||
All settings can be configured via environment variables with the prefix FASTMCP_.
|
All settings can be configured via environment variables with the prefix FASTMCP_.
|
||||||
@@ -85,13 +98,36 @@ class Settings(BaseSettings):
|
|||||||
description="List of dependencies to install in the server environment",
|
description="List of dependencies to install in the server environment",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
lifespan: (
|
||||||
|
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||||
|
) = Field(None, description="Lifespan context manager")
|
||||||
|
|
||||||
|
|
||||||
|
def lifespan_wrapper(
|
||||||
|
app: "FastMCP",
|
||||||
|
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
|
||||||
|
) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]:
|
||||||
|
@asynccontextmanager
|
||||||
|
async def wrap(s: MCPServer) -> AsyncIterator[object]:
|
||||||
|
async with lifespan(app) as context:
|
||||||
|
yield context
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
class FastMCP:
|
class FastMCP:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str | None = None, instructions: str | None = None, **settings: Any
|
self, name: str | None = None, instructions: str | None = None, **settings: Any
|
||||||
):
|
):
|
||||||
self.settings = Settings(**settings)
|
self.settings = Settings(**settings)
|
||||||
self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions)
|
|
||||||
|
self._mcp_server = MCPServer(
|
||||||
|
name=name or "FastMCP",
|
||||||
|
instructions=instructions,
|
||||||
|
lifespan=lifespan_wrapper(self, self.settings.lifespan)
|
||||||
|
if self.settings.lifespan
|
||||||
|
else default_lifespan,
|
||||||
|
)
|
||||||
self._tool_manager = ToolManager(
|
self._tool_manager = ToolManager(
|
||||||
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
|
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -68,7 +68,8 @@ import contextvars
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any, Sequence
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
|
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
|
||||||
|
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
@@ -84,7 +85,10 @@ from mcp.shared.session import RequestResponder
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
|
LifespanResultT = TypeVar("LifespanResultT")
|
||||||
|
|
||||||
|
# This will be properly typed in each Server instance's context
|
||||||
|
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
|
||||||
contextvars.ContextVar("request_ctx")
|
contextvars.ContextVar("request_ctx")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -101,13 +105,33 @@ class NotificationOptions:
|
|||||||
self.tools_changed = tools_changed
|
self.tools_changed = tools_changed
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
@asynccontextmanager
|
||||||
|
async def lifespan(server: "Server") -> AsyncIterator[object]:
|
||||||
|
"""Default lifespan context manager that does nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: The server instance this lifespan is managing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An empty context object
|
||||||
|
"""
|
||||||
|
yield {}
|
||||||
|
|
||||||
|
|
||||||
|
class Server(Generic[LifespanResultT]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str, version: str | None = None, instructions: str | None = None
|
self,
|
||||||
|
name: str,
|
||||||
|
version: str | None = None,
|
||||||
|
instructions: str | None = None,
|
||||||
|
lifespan: Callable[
|
||||||
|
["Server"], AbstractAsyncContextManager[LifespanResultT]
|
||||||
|
] = lifespan,
|
||||||
):
|
):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.version = version
|
self.version = version
|
||||||
self.instructions = instructions
|
self.instructions = instructions
|
||||||
|
self.lifespan = lifespan
|
||||||
self.request_handlers: dict[
|
self.request_handlers: dict[
|
||||||
type, Callable[..., Awaitable[types.ServerResult]]
|
type, Callable[..., Awaitable[types.ServerResult]]
|
||||||
] = {
|
] = {
|
||||||
@@ -188,7 +212,7 @@ class Server:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def request_context(self) -> RequestContext[ServerSession]:
|
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
|
||||||
"""If called outside of a request context, this will raise a LookupError."""
|
"""If called outside of a request context, this will raise a LookupError."""
|
||||||
return request_ctx.get()
|
return request_ctx.get()
|
||||||
|
|
||||||
@@ -446,9 +470,14 @@ class Server:
|
|||||||
raise_exceptions: bool = False,
|
raise_exceptions: bool = False,
|
||||||
):
|
):
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
async with ServerSession(
|
from contextlib import AsyncExitStack
|
||||||
read_stream, write_stream, initialization_options
|
|
||||||
) as session:
|
async with AsyncExitStack() as stack:
|
||||||
|
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||||
|
session = await stack.enter_async_context(
|
||||||
|
ServerSession(read_stream, write_stream, initialization_options)
|
||||||
|
)
|
||||||
|
|
||||||
async for message in session.incoming_messages:
|
async for message in session.incoming_messages:
|
||||||
logger.debug(f"Received message: {message}")
|
logger.debug(f"Received message: {message}")
|
||||||
|
|
||||||
@@ -460,14 +489,20 @@ class Server:
|
|||||||
):
|
):
|
||||||
with responder:
|
with responder:
|
||||||
await self._handle_request(
|
await self._handle_request(
|
||||||
message, req, session, raise_exceptions
|
message,
|
||||||
|
req,
|
||||||
|
session,
|
||||||
|
lifespan_context,
|
||||||
|
raise_exceptions,
|
||||||
)
|
)
|
||||||
case types.ClientNotification(root=notify):
|
case types.ClientNotification(root=notify):
|
||||||
await self._handle_notification(notify)
|
await self._handle_notification(notify)
|
||||||
|
|
||||||
for warning in w:
|
for warning in w:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
"Warning: %s: %s",
|
||||||
|
warning.category.__name__,
|
||||||
|
warning.message,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_request(
|
async def _handle_request(
|
||||||
@@ -475,6 +510,7 @@ class Server:
|
|||||||
message: RequestResponder,
|
message: RequestResponder,
|
||||||
req: Any,
|
req: Any,
|
||||||
session: ServerSession,
|
session: ServerSession,
|
||||||
|
lifespan_context: LifespanResultT,
|
||||||
raise_exceptions: bool,
|
raise_exceptions: bool,
|
||||||
):
|
):
|
||||||
logger.info(f"Processing request of type {type(req).__name__}")
|
logger.info(f"Processing request of type {type(req).__name__}")
|
||||||
@@ -491,6 +527,7 @@ class Server:
|
|||||||
message.request_id,
|
message.request_id,
|
||||||
message.request_meta,
|
message.request_meta,
|
||||||
session,
|
session,
|
||||||
|
lifespan_context,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await handler(req)
|
response = await handler(req)
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ from mcp.shared.session import BaseSession
|
|||||||
from mcp.types import RequestId, RequestParams
|
from mcp.types import RequestId, RequestParams
|
||||||
|
|
||||||
SessionT = TypeVar("SessionT", bound=BaseSession)
|
SessionT = TypeVar("SessionT", bound=BaseSession)
|
||||||
|
LifespanContextT = TypeVar("LifespanContextT")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestContext(Generic[SessionT]):
|
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||||
request_id: RequestId
|
request_id: RequestId
|
||||||
meta: RequestParams.Meta | None
|
meta: RequestParams.Meta | None
|
||||||
session: SessionT
|
session: SessionT
|
||||||
|
lifespan_context: LifespanContextT
|
||||||
|
|||||||
@@ -20,7 +20,10 @@ async def test_progress_token_zero_first_call():
|
|||||||
mock_meta.progressToken = 0 # This is the key test case - token is 0
|
mock_meta.progressToken = 0 # This is the key test case - token is 0
|
||||||
|
|
||||||
request_context = RequestContext(
|
request_context = RequestContext(
|
||||||
request_id="test-request", session=mock_session, meta=mock_meta
|
request_id="test-request",
|
||||||
|
session=mock_session,
|
||||||
|
meta=mock_meta,
|
||||||
|
lifespan_context=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create context with our mocks
|
# Create context with our mocks
|
||||||
|
|||||||
@@ -236,7 +236,7 @@ async def test_lambda_function():
|
|||||||
|
|
||||||
def test_complex_function_json_schema():
|
def test_complex_function_json_schema():
|
||||||
"""Test JSON schema generation for complex function arguments.
|
"""Test JSON schema generation for complex function arguments.
|
||||||
|
|
||||||
Note: Different versions of pydantic output slightly different
|
Note: Different versions of pydantic output slightly different
|
||||||
JSON Schema formats for model fields with defaults. The format changed in 2.9.0:
|
JSON Schema formats for model fields with defaults. The format changed in 2.9.0:
|
||||||
|
|
||||||
@@ -245,16 +245,16 @@ def test_complex_function_json_schema():
|
|||||||
"allOf": [{"$ref": "#/$defs/Model"}],
|
"allOf": [{"$ref": "#/$defs/Model"}],
|
||||||
"default": {}
|
"default": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
2. Since 2.9.0:
|
2. Since 2.9.0:
|
||||||
{
|
{
|
||||||
"$ref": "#/$defs/Model",
|
"$ref": "#/$defs/Model",
|
||||||
"default": {}
|
"default": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
Both formats are valid and functionally equivalent. This test accepts either format
|
Both formats are valid and functionally equivalent. This test accepts either format
|
||||||
to ensure compatibility across our supported pydantic versions.
|
to ensure compatibility across our supported pydantic versions.
|
||||||
|
|
||||||
This change in format does not affect runtime behavior since:
|
This change in format does not affect runtime behavior since:
|
||||||
1. Both schemas validate the same way
|
1. Both schemas validate the same way
|
||||||
2. The actual model classes and validation logic are unchanged
|
2. The actual model classes and validation logic are unchanged
|
||||||
@@ -262,17 +262,17 @@ def test_complex_function_json_schema():
|
|||||||
"""
|
"""
|
||||||
meta = func_metadata(complex_arguments_fn)
|
meta = func_metadata(complex_arguments_fn)
|
||||||
actual_schema = meta.arg_model.model_json_schema()
|
actual_schema = meta.arg_model.model_json_schema()
|
||||||
|
|
||||||
# Create a copy of the actual schema to normalize
|
# Create a copy of the actual schema to normalize
|
||||||
normalized_schema = actual_schema.copy()
|
normalized_schema = actual_schema.copy()
|
||||||
|
|
||||||
# Normalize the my_model_a_with_default field to handle both pydantic formats
|
# Normalize the my_model_a_with_default field to handle both pydantic formats
|
||||||
if 'allOf' in actual_schema['properties']['my_model_a_with_default']:
|
if "allOf" in actual_schema["properties"]["my_model_a_with_default"]:
|
||||||
normalized_schema['properties']['my_model_a_with_default'] = {
|
normalized_schema["properties"]["my_model_a_with_default"] = {
|
||||||
'$ref': '#/$defs/SomeInputModelA',
|
"$ref": "#/$defs/SomeInputModelA",
|
||||||
'default': {}
|
"default": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert normalized_schema == {
|
assert normalized_schema == {
|
||||||
"$defs": {
|
"$defs": {
|
||||||
"InnerModel": {
|
"InnerModel": {
|
||||||
|
|||||||
207
tests/server/test_lifespan.py
Normal file
207
tests/server/test_lifespan.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""Tests for lifespan functionality in both low-level and FastMCP servers."""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
import pytest
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import Context, FastMCP
|
||||||
|
from mcp.server.lowlevel.server import NotificationOptions, Server
|
||||||
|
from mcp.server.models import InitializationOptions
|
||||||
|
from mcp.types import (
|
||||||
|
ClientCapabilities,
|
||||||
|
Implementation,
|
||||||
|
InitializeRequestParams,
|
||||||
|
JSONRPCMessage,
|
||||||
|
JSONRPCNotification,
|
||||||
|
JSONRPCRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_lowlevel_server_lifespan():
|
||||||
|
"""Test that lifespan works in low-level server."""
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def test_lifespan(server: Server) -> AsyncIterator[dict]:
|
||||||
|
"""Test lifespan context that tracks startup/shutdown."""
|
||||||
|
context = {"started": False, "shutdown": False}
|
||||||
|
try:
|
||||||
|
context["started"] = True
|
||||||
|
yield context
|
||||||
|
finally:
|
||||||
|
context["shutdown"] = True
|
||||||
|
|
||||||
|
server = Server("test", lifespan=test_lifespan)
|
||||||
|
|
||||||
|
# Create memory streams for testing
|
||||||
|
send_stream1, receive_stream1 = anyio.create_memory_object_stream(100)
|
||||||
|
send_stream2, receive_stream2 = anyio.create_memory_object_stream(100)
|
||||||
|
|
||||||
|
# Create a tool that accesses lifespan context
|
||||||
|
@server.call_tool()
|
||||||
|
async def check_lifespan(name: str, arguments: dict) -> list:
|
||||||
|
ctx = server.request_context
|
||||||
|
assert isinstance(ctx.lifespan_context, dict)
|
||||||
|
assert ctx.lifespan_context["started"]
|
||||||
|
assert not ctx.lifespan_context["shutdown"]
|
||||||
|
return [{"type": "text", "text": "true"}]
|
||||||
|
|
||||||
|
# Run server in background task
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
|
||||||
|
async def run_server():
|
||||||
|
await server.run(
|
||||||
|
receive_stream1,
|
||||||
|
send_stream2,
|
||||||
|
InitializationOptions(
|
||||||
|
server_name="test",
|
||||||
|
server_version="0.1.0",
|
||||||
|
capabilities=server.get_capabilities(
|
||||||
|
notification_options=NotificationOptions(),
|
||||||
|
experimental_capabilities={},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
raise_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tg.start_soon(run_server)
|
||||||
|
|
||||||
|
# Initialize the server
|
||||||
|
params = InitializeRequestParams(
|
||||||
|
protocolVersion="2024-11-05",
|
||||||
|
capabilities=ClientCapabilities(),
|
||||||
|
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
||||||
|
)
|
||||||
|
await send_stream1.send(
|
||||||
|
JSONRPCMessage(
|
||||||
|
root=JSONRPCRequest(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
id=1,
|
||||||
|
method="initialize",
|
||||||
|
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
response = await receive_stream2.receive()
|
||||||
|
|
||||||
|
# Send initialized notification
|
||||||
|
await send_stream1.send(
|
||||||
|
JSONRPCMessage(
|
||||||
|
root=JSONRPCNotification(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
method="notifications/initialized",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the tool to verify lifespan context
|
||||||
|
await send_stream1.send(
|
||||||
|
JSONRPCMessage(
|
||||||
|
root=JSONRPCRequest(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
id=2,
|
||||||
|
method="tools/call",
|
||||||
|
params={"name": "check_lifespan", "arguments": {}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get response and verify
|
||||||
|
response = await receive_stream2.receive()
|
||||||
|
assert response.root.result["content"][0]["text"] == "true"
|
||||||
|
|
||||||
|
# Cancel server task
|
||||||
|
tg.cancel_scope.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fastmcp_server_lifespan():
|
||||||
|
"""Test that lifespan works in FastMCP server."""
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def test_lifespan(server: FastMCP) -> AsyncIterator[dict]:
|
||||||
|
"""Test lifespan context that tracks startup/shutdown."""
|
||||||
|
context = {"started": False, "shutdown": False}
|
||||||
|
try:
|
||||||
|
context["started"] = True
|
||||||
|
yield context
|
||||||
|
finally:
|
||||||
|
context["shutdown"] = True
|
||||||
|
|
||||||
|
server = FastMCP("test", lifespan=test_lifespan)
|
||||||
|
|
||||||
|
# Create memory streams for testing
|
||||||
|
send_stream1, receive_stream1 = anyio.create_memory_object_stream(100)
|
||||||
|
send_stream2, receive_stream2 = anyio.create_memory_object_stream(100)
|
||||||
|
|
||||||
|
# Add a tool that checks lifespan context
|
||||||
|
@server.tool()
|
||||||
|
def check_lifespan(ctx: Context) -> bool:
|
||||||
|
"""Tool that checks lifespan context."""
|
||||||
|
assert isinstance(ctx.request_context.lifespan_context, dict)
|
||||||
|
assert ctx.request_context.lifespan_context["started"]
|
||||||
|
assert not ctx.request_context.lifespan_context["shutdown"]
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Run server in background task
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
|
||||||
|
async def run_server():
|
||||||
|
await server._mcp_server.run(
|
||||||
|
receive_stream1,
|
||||||
|
send_stream2,
|
||||||
|
server._mcp_server.create_initialization_options(),
|
||||||
|
raise_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tg.start_soon(run_server)
|
||||||
|
|
||||||
|
# Initialize the server
|
||||||
|
params = InitializeRequestParams(
|
||||||
|
protocolVersion="2024-11-05",
|
||||||
|
capabilities=ClientCapabilities(),
|
||||||
|
clientInfo=Implementation(name="test-client", version="0.1.0"),
|
||||||
|
)
|
||||||
|
await send_stream1.send(
|
||||||
|
JSONRPCMessage(
|
||||||
|
root=JSONRPCRequest(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
id=1,
|
||||||
|
method="initialize",
|
||||||
|
params=TypeAdapter(InitializeRequestParams).dump_python(params),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
response = await receive_stream2.receive()
|
||||||
|
|
||||||
|
# Send initialized notification
|
||||||
|
await send_stream1.send(
|
||||||
|
JSONRPCMessage(
|
||||||
|
root=JSONRPCNotification(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
method="notifications/initialized",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the tool to verify lifespan context
|
||||||
|
await send_stream1.send(
|
||||||
|
JSONRPCMessage(
|
||||||
|
root=JSONRPCRequest(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
id=2,
|
||||||
|
method="tools/call",
|
||||||
|
params={"name": "check_lifespan", "arguments": {}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get response and verify
|
||||||
|
response = await receive_stream2.receive()
|
||||||
|
assert response.root.result["content"][0]["text"] == "true"
|
||||||
|
|
||||||
|
# Cancel server task
|
||||||
|
tg.cancel_scope.cancel()
|
||||||
Reference in New Issue
Block a user