From 557e90d2e77e9479a4a2431742ac8c19e4d56b13 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 9 Dec 2024 16:16:47 +0000 Subject: [PATCH] Integrate FastMCP This commit integrates FastMCP, a high-level MCP server implementation originally written by Jeremiah Lowin, into the official MCP SDK. It also updates dependencies and adds new dev dependencies. It moves the existing SDK into a .lowlevel . --- pyproject.toml | 35 +- src/mcp/server/__init__.py | 502 +------------ src/mcp/server/fastmcp/__init__.py | 8 + src/mcp/server/fastmcp/exceptions.py | 21 + src/mcp/server/fastmcp/prompts/__init__.py | 4 + src/mcp/server/fastmcp/prompts/base.py | 166 +++++ src/mcp/server/fastmcp/prompts/manager.py | 50 ++ .../server/fastmcp/prompts/prompt_manager.py | 33 + src/mcp/server/fastmcp/resources/__init__.py | 23 + src/mcp/server/fastmcp/resources/base.py | 48 ++ .../fastmcp/resources/resource_manager.py | 95 +++ src/mcp/server/fastmcp/resources/templates.py | 80 +++ src/mcp/server/fastmcp/resources/types.py | 181 +++++ src/mcp/server/fastmcp/server.py | 668 ++++++++++++++++++ src/mcp/server/fastmcp/tools/__init__.py | 4 + src/mcp/server/fastmcp/tools/base.py | 82 +++ src/mcp/server/fastmcp/tools/tool_manager.py | 54 ++ src/mcp/server/fastmcp/utilities/__init__.py | 1 + .../server/fastmcp/utilities/func_metadata.py | 210 ++++++ src/mcp/server/fastmcp/utilities/logging.py | 41 ++ src/mcp/server/fastmcp/utilities/types.py | 54 ++ src/mcp/server/lowlevel/__init__.py | 3 + src/mcp/server/lowlevel/server.py | 500 +++++++++++++ tests/conftest.py | 4 + tests/server/fastmcp/__init__.py | 0 tests/server/fastmcp/prompts/__init__.py | 0 tests/server/fastmcp/prompts/test_base.py | 194 +++++ tests/server/fastmcp/prompts/test_manager.py | 107 +++ tests/server/fastmcp/resources/__init__.py | 0 .../fastmcp/resources/test_file_resources.py | 115 +++ .../resources/test_function_resources.py | 115 +++ .../resources/test_resource_manager.py | 137 ++++ .../resources/test_resource_template.py | 181 +++++ .../fastmcp/resources/test_resources.py | 100 +++ tests/server/fastmcp/servers/__init__.py | 0 .../fastmcp/servers/test_file_server.py | 114 +++ tests/server/fastmcp/test_func_metadata.py | 361 ++++++++++ tests/server/fastmcp/test_server.py | 656 +++++++++++++++++ tests/server/fastmcp/test_tool_manager.py | 306 ++++++++ tests/server/test_session.py | 3 +- uv.lock | 140 +++- 41 files changed, 4875 insertions(+), 521 deletions(-) create mode 100644 src/mcp/server/fastmcp/__init__.py create mode 100644 src/mcp/server/fastmcp/exceptions.py create mode 100644 src/mcp/server/fastmcp/prompts/__init__.py create mode 100644 src/mcp/server/fastmcp/prompts/base.py create mode 100644 src/mcp/server/fastmcp/prompts/manager.py create mode 100644 src/mcp/server/fastmcp/prompts/prompt_manager.py create mode 100644 src/mcp/server/fastmcp/resources/__init__.py create mode 100644 src/mcp/server/fastmcp/resources/base.py create mode 100644 src/mcp/server/fastmcp/resources/resource_manager.py create mode 100644 src/mcp/server/fastmcp/resources/templates.py create mode 100644 src/mcp/server/fastmcp/resources/types.py create mode 100644 src/mcp/server/fastmcp/server.py create mode 100644 src/mcp/server/fastmcp/tools/__init__.py create mode 100644 src/mcp/server/fastmcp/tools/base.py create mode 100644 src/mcp/server/fastmcp/tools/tool_manager.py create mode 100644 src/mcp/server/fastmcp/utilities/__init__.py create mode 100644 src/mcp/server/fastmcp/utilities/func_metadata.py create mode 100644 src/mcp/server/fastmcp/utilities/logging.py create mode 100644 src/mcp/server/fastmcp/utilities/types.py create mode 100644 src/mcp/server/lowlevel/__init__.py create mode 100644 src/mcp/server/lowlevel/server.py create mode 100644 tests/server/fastmcp/__init__.py create mode 100644 tests/server/fastmcp/prompts/__init__.py create mode 100644 tests/server/fastmcp/prompts/test_base.py create mode 100644 tests/server/fastmcp/prompts/test_manager.py create mode 100644 tests/server/fastmcp/resources/__init__.py create mode 100644 tests/server/fastmcp/resources/test_file_resources.py create mode 100644 tests/server/fastmcp/resources/test_function_resources.py create mode 100644 tests/server/fastmcp/resources/test_resource_manager.py create mode 100644 tests/server/fastmcp/resources/test_resource_template.py create mode 100644 tests/server/fastmcp/resources/test_resources.py create mode 100644 tests/server/fastmcp/servers/__init__.py create mode 100644 tests/server/fastmcp/servers/test_file_server.py create mode 100644 tests/server/fastmcp/test_func_metadata.py create mode 100644 tests/server/fastmcp/test_server.py create mode 100644 tests/server/fastmcp/test_tool_manager.py diff --git a/pyproject.toml b/pyproject.toml index efde945..9ba6b9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,3 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - [project] name = "mcp" version = "1.1.2.dev0" @@ -29,11 +25,31 @@ dependencies = [ "anyio>=4.5", "httpx>=0.27", "httpx-sse>=0.4", - "pydantic>=2.7.2", + "pydantic>=2.7.2,<3.0.0", "starlette>=0.27", "sse-starlette>=1.6.1", + "pydantic-settings>=2.6.1", ] +[project.optional-dependencies] +rich = ["rich>=13.9.4"] + +[tool.uv] +resolution = "lowest-direct" +dev-dependencies = [ + "pyright>=1.1.378", + "pytest>=8.3.3", + "ruff>=0.6.9", + "trio>=0.26.2", + "pytest-flakefinder>=1.1.0", + "pytest-xdist>=3.6.1", + "pytest-asyncio>=0.24.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + [project.urls] Homepage = "https://modelcontextprotocol.io" Repository = "https://github.com/modelcontextprotocol/python-sdk" @@ -58,15 +74,6 @@ target-version = "py310" [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] -[tool.uv] -resolution = "lowest-direct" -dev-dependencies = [ - "pyright>=1.1.378", - "pytest>=8.3.3", - "ruff>=0.6.9", - "trio>=0.26.2", -] - [tool.uv.workspace] members = ["examples/servers/*"] diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index a0dd033..4db3e6d 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -1,500 +1,4 @@ -""" -MCP Server Module +from .lowlevel import Server, NotificationOptions +from .fastmcp import FastMCP -This module provides a framework for creating an MCP (Model Context Protocol) server. -It allows you to easily define and handle various types of requests and notifications -in an asynchronous manner. - -Usage: -1. Create a Server instance: - server = Server("your_server_name") - -2. Define request handlers using decorators: - @server.list_prompts() - async def handle_list_prompts() -> list[types.Prompt]: - # Implementation - - @server.get_prompt() - async def handle_get_prompt( - name: str, arguments: dict[str, str] | None - ) -> types.GetPromptResult: - # Implementation - - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - # Implementation - - @server.call_tool() - async def handle_call_tool( - name: str, arguments: dict | None - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - # Implementation - - @server.list_resource_templates() - async def handle_list_resource_templates() -> list[types.ResourceTemplate]: - # Implementation - -3. Define notification handlers if needed: - @server.progress_notification() - async def handle_progress( - progress_token: str | int, progress: float, total: float | None - ) -> None: - # Implementation - -4. Run the server: - async def main(): - async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): - await server.run( - read_stream, - write_stream, - InitializationOptions( - server_name="your_server_name", - server_version="your_version", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) - - asyncio.run(main()) - -The Server class provides methods to register handlers for various MCP requests and -notifications. It automatically manages the request context and handles incoming -messages from the client. -""" - -import contextvars -import logging -import warnings -from collections.abc import Awaitable, Callable -from typing import Any, Sequence - -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl - -import mcp.types as types -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.server.stdio import stdio_server as stdio_server -from mcp.shared.context import RequestContext -from mcp.shared.exceptions import McpError -from mcp.shared.session import RequestResponder - -logger = logging.getLogger(__name__) - -request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = ( - contextvars.ContextVar("request_ctx") -) - - -class NotificationOptions: - def __init__( - self, - prompts_changed: bool = False, - resources_changed: bool = False, - tools_changed: bool = False, - ): - self.prompts_changed = prompts_changed - self.resources_changed = resources_changed - self.tools_changed = tools_changed - - -class Server: - def __init__(self, name: str): - self.name = name - self.request_handlers: dict[ - type, Callable[..., Awaitable[types.ServerResult]] - ] = { - types.PingRequest: _ping_handler, - } - self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - self.notification_options = NotificationOptions() - logger.debug(f"Initializing server '{name}'") - - def create_initialization_options( - self, - notification_options: NotificationOptions | None = None, - experimental_capabilities: dict[str, dict[str, Any]] | None = None, - ) -> InitializationOptions: - """Create initialization options from this server instance.""" - - def pkg_version(package: str) -> str: - try: - from importlib.metadata import version - - v = version(package) - if v is not None: - return v - except Exception: - pass - - return "unknown" - - return InitializationOptions( - server_name=self.name, - server_version=pkg_version("mcp"), - capabilities=self.get_capabilities( - notification_options or NotificationOptions(), - experimental_capabilities or {}, - ), - ) - - def get_capabilities( - self, - notification_options: NotificationOptions, - experimental_capabilities: dict[str, dict[str, Any]], - ) -> types.ServerCapabilities: - """Convert existing handlers to a ServerCapabilities object.""" - prompts_capability = None - resources_capability = None - tools_capability = None - logging_capability = None - - # Set prompt capabilities if handler exists - if types.ListPromptsRequest in self.request_handlers: - prompts_capability = types.PromptsCapability( - listChanged=notification_options.prompts_changed - ) - - # Set resource capabilities if handler exists - if types.ListResourcesRequest in self.request_handlers: - resources_capability = types.ResourcesCapability( - subscribe=False, listChanged=notification_options.resources_changed - ) - - # Set tool capabilities if handler exists - if types.ListToolsRequest in self.request_handlers: - tools_capability = types.ToolsCapability( - listChanged=notification_options.tools_changed - ) - - # Set logging capabilities if handler exists - if types.SetLevelRequest in self.request_handlers: - logging_capability = types.LoggingCapability() - - return types.ServerCapabilities( - prompts=prompts_capability, - resources=resources_capability, - tools=tools_capability, - logging=logging_capability, - experimental=experimental_capabilities, - ) - - @property - def request_context(self) -> RequestContext[ServerSession]: - """If called outside of a request context, this will raise a LookupError.""" - return request_ctx.get() - - def list_prompts(self): - def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]): - logger.debug("Registering handler for PromptListRequest") - - async def handler(_: Any): - prompts = await func() - return types.ServerResult(types.ListPromptsResult(prompts=prompts)) - - self.request_handlers[types.ListPromptsRequest] = handler - return func - - return decorator - - def get_prompt(self): - def decorator( - func: Callable[ - [str, dict[str, str] | None], Awaitable[types.GetPromptResult] - ], - ): - logger.debug("Registering handler for GetPromptRequest") - - async def handler(req: types.GetPromptRequest): - prompt_get = await func(req.params.name, req.params.arguments) - return types.ServerResult(prompt_get) - - self.request_handlers[types.GetPromptRequest] = handler - return func - - return decorator - - def list_resources(self): - def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): - logger.debug("Registering handler for ListResourcesRequest") - - async def handler(_: Any): - resources = await func() - return types.ServerResult( - types.ListResourcesResult(resources=resources) - ) - - self.request_handlers[types.ListResourcesRequest] = handler - return func - - return decorator - - def list_resource_templates(self): - def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): - logger.debug("Registering handler for ListResourceTemplatesRequest") - - async def handler(_: Any): - templates = await func() - return types.ServerResult( - types.ListResourceTemplatesResult(resourceTemplates=templates) - ) - - self.request_handlers[types.ListResourceTemplatesRequest] = handler - return func - - return decorator - - def read_resource(self): - def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]): - logger.debug("Registering handler for ReadResourceRequest") - - async def handler(req: types.ReadResourceRequest): - result = await func(req.params.uri) - match result: - case str(s): - content = types.TextResourceContents( - uri=req.params.uri, - text=s, - mimeType="text/plain", - ) - case bytes(b): - import base64 - - content = types.BlobResourceContents( - uri=req.params.uri, - blob=base64.urlsafe_b64encode(b).decode(), - mimeType="application/octet-stream", - ) - - return types.ServerResult( - types.ReadResourceResult( - contents=[content], - ) - ) - - self.request_handlers[types.ReadResourceRequest] = handler - return func - - return decorator - - def set_logging_level(self): - def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): - logger.debug("Registering handler for SetLevelRequest") - - async def handler(req: types.SetLevelRequest): - await func(req.params.level) - return types.ServerResult(types.EmptyResult()) - - self.request_handlers[types.SetLevelRequest] = handler - return func - - return decorator - - def subscribe_resource(self): - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): - logger.debug("Registering handler for SubscribeRequest") - - async def handler(req: types.SubscribeRequest): - await func(req.params.uri) - return types.ServerResult(types.EmptyResult()) - - self.request_handlers[types.SubscribeRequest] = handler - return func - - return decorator - - def unsubscribe_resource(self): - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): - logger.debug("Registering handler for UnsubscribeRequest") - - async def handler(req: types.UnsubscribeRequest): - await func(req.params.uri) - return types.ServerResult(types.EmptyResult()) - - self.request_handlers[types.UnsubscribeRequest] = handler - return func - - return decorator - - def list_tools(self): - def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): - logger.debug("Registering handler for ListToolsRequest") - - async def handler(_: Any): - tools = await func() - return types.ServerResult(types.ListToolsResult(tools=tools)) - - self.request_handlers[types.ListToolsRequest] = handler - return func - - return decorator - - def call_tool(self): - def decorator( - func: Callable[ - ..., - Awaitable[ - Sequence[ - types.TextContent | types.ImageContent | types.EmbeddedResource - ] - ], - ], - ): - logger.debug("Registering handler for CallToolRequest") - - async def handler(req: types.CallToolRequest): - try: - results = await func(req.params.name, (req.params.arguments or {})) - return types.ServerResult( - types.CallToolResult(content=list(results), isError=False) - ) - except Exception as e: - return types.ServerResult( - types.CallToolResult( - content=[types.TextContent(type="text", text=str(e))], - isError=True, - ) - ) - - self.request_handlers[types.CallToolRequest] = handler - return func - - return decorator - - def progress_notification(self): - def decorator( - func: Callable[[str | int, float, float | None], Awaitable[None]], - ): - logger.debug("Registering handler for ProgressNotification") - - async def handler(req: types.ProgressNotification): - await func( - req.params.progressToken, req.params.progress, req.params.total - ) - - self.notification_handlers[types.ProgressNotification] = handler - return func - - return decorator - - def completion(self): - """Provides completions for prompts and resource templates""" - - def decorator( - func: Callable[ - [ - types.PromptReference | types.ResourceReference, - types.CompletionArgument, - ], - Awaitable[types.Completion | None], - ], - ): - logger.debug("Registering handler for CompleteRequest") - - async def handler(req: types.CompleteRequest): - completion = await func(req.params.ref, req.params.argument) - return types.ServerResult( - types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, hasMore=None), - ) - ) - - self.request_handlers[types.CompleteRequest] = handler - return func - - return decorator - - async def run( - self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], - initialization_options: InitializationOptions, - # When False, exceptions are returned as messages to the client. - # When True, exceptions are raised, which will cause the server to shut down - # but also make tracing exceptions much easier during testing and when using - # in-process servers. - raise_exceptions: bool = False, - ): - with warnings.catch_warnings(record=True) as w: - async with ServerSession( - read_stream, write_stream, initialization_options - ) as session: - async for message in session.incoming_messages: - logger.debug(f"Received message: {message}") - - match message: - case RequestResponder(request=types.ClientRequest(root=req)): - logger.info( - f"Processing request of type {type(req).__name__}" - ) - if type(req) in self.request_handlers: - handler = self.request_handlers[type(req)] - logger.debug( - f"Dispatching request of type {type(req).__name__}" - ) - - token = None - try: - # Set our global state that can be retrieved via - # app.get_request_context() - token = request_ctx.set( - RequestContext( - message.request_id, - message.request_meta, - session, - ) - ) - response = await handler(req) - except McpError as err: - response = err.error - except Exception as err: - if raise_exceptions: - raise err - response = types.ErrorData( - code=0, message=str(err), data=None - ) - finally: - # Reset the global state after we are done - if token is not None: - request_ctx.reset(token) - - await message.respond(response) - else: - await message.respond( - types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="Method not found", - ) - ) - - logger.debug("Response sent") - case types.ClientNotification(root=notify): - if type(notify) in self.notification_handlers: - assert type(notify) in self.notification_handlers - - handler = self.notification_handlers[type(notify)] - logger.debug( - f"Dispatching notification of type " - f"{type(notify).__name__}" - ) - - try: - await handler(notify) - except Exception as err: - logger.error( - f"Uncaught exception in notification handler: " - f"{err}" - ) - - for warning in w: - logger.info( - f"Warning: {warning.category.__name__}: {warning.message}" - ) - - -async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.ServerResult(types.EmptyResult()) +__all__ = ["Server", "FastMCP", "NotificationOptions"] diff --git a/src/mcp/server/fastmcp/__init__.py b/src/mcp/server/fastmcp/__init__.py new file mode 100644 index 0000000..4ff1a05 --- /dev/null +++ b/src/mcp/server/fastmcp/__init__.py @@ -0,0 +1,8 @@ +"""FastMCP - A more ergonomic interface for MCP servers.""" + +from importlib.metadata import version +from .server import FastMCP, Context +from .utilities.types import Image + +__version__ = version("mcp") +__all__ = ["FastMCP", "Context", "Image"] diff --git a/src/mcp/server/fastmcp/exceptions.py b/src/mcp/server/fastmcp/exceptions.py new file mode 100644 index 0000000..fb5bda1 --- /dev/null +++ b/src/mcp/server/fastmcp/exceptions.py @@ -0,0 +1,21 @@ +"""Custom exceptions for FastMCP.""" + + +class FastMCPError(Exception): + """Base error for FastMCP.""" + + +class ValidationError(FastMCPError): + """Error in validating parameters or return values.""" + + +class ResourceError(FastMCPError): + """Error in resource operations.""" + + +class ToolError(FastMCPError): + """Error in tool operations.""" + + +class InvalidSignature(Exception): + """Invalid signature for use with FastMCP.""" diff --git a/src/mcp/server/fastmcp/prompts/__init__.py b/src/mcp/server/fastmcp/prompts/__init__.py new file mode 100644 index 0000000..7637269 --- /dev/null +++ b/src/mcp/server/fastmcp/prompts/__init__.py @@ -0,0 +1,4 @@ +from .base import Prompt +from .manager import PromptManager + +__all__ = ["Prompt", "PromptManager"] diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py new file mode 100644 index 0000000..8358f4b --- /dev/null +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -0,0 +1,166 @@ +"""Base classes for FastMCP prompts.""" + +import json +from typing import Any, Literal, Sequence, Awaitable +import inspect +from collections.abc import Callable + +from pydantic import BaseModel, Field, TypeAdapter, validate_call +from mcp.types import TextContent, ImageContent, EmbeddedResource +import pydantic_core + +CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource + + +class Message(BaseModel): + """Base class for all prompt messages.""" + + role: Literal["user", "assistant"] + content: CONTENT_TYPES + + def __init__(self, content: str | CONTENT_TYPES, **kwargs): + if isinstance(content, str): + content = TextContent(type="text", text=content) + super().__init__(content=content, **kwargs) + + +class UserMessage(Message): + """A message from the user.""" + + role: Literal["user", "assistant"] = "user" + + def __init__(self, content: str | CONTENT_TYPES, **kwargs): + super().__init__(content=content, **kwargs) + + +class AssistantMessage(Message): + """A message from the assistant.""" + + role: Literal["user", "assistant"] = "assistant" + + def __init__(self, content: str | CONTENT_TYPES, **kwargs): + super().__init__(content=content, **kwargs) + + +message_validator = TypeAdapter(UserMessage | AssistantMessage) + +SyncPromptResult = ( + str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] +) +PromptResult = SyncPromptResult | Awaitable[SyncPromptResult] + + +class PromptArgument(BaseModel): + """An argument that can be passed to a prompt.""" + + name: str = Field(description="Name of the argument") + description: str | None = Field( + None, description="Description of what the argument does" + ) + required: bool = Field( + default=False, description="Whether the argument is required" + ) + + +class Prompt(BaseModel): + """A prompt template that can be rendered with parameters.""" + + name: str = Field(description="Name of the prompt") + description: str | None = Field( + None, description="Description of what the prompt does" + ) + arguments: list[PromptArgument] | None = Field( + None, description="Arguments that can be passed to the prompt" + ) + fn: Callable = Field(exclude=True) + + @classmethod + def from_function( + cls, + fn: Callable[..., PromptResult], + name: str | None = None, + description: str | None = None, + ) -> "Prompt": + """Create a Prompt from a function. + + The function can return: + - A string (converted to a message) + - A Message object + - A dict (converted to a message) + - A sequence of any of the above + """ + func_name = name or fn.__name__ + + if func_name == "": + raise ValueError("You must provide a name for lambda functions") + + # Get schema from TypeAdapter - will fail if function isn't properly typed + parameters = TypeAdapter(fn).json_schema() + + # Convert parameters to PromptArguments + arguments = [] + if "properties" in parameters: + for param_name, param in parameters["properties"].items(): + required = param_name in parameters.get("required", []) + arguments.append( + PromptArgument( + name=param_name, + description=param.get("description"), + required=required, + ) + ) + + # ensure the arguments are properly cast + fn = validate_call(fn) + + return cls( + name=func_name, + description=description or fn.__doc__ or "", + arguments=arguments, + fn=fn, + ) + + async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]: + """Render the prompt with arguments.""" + # Validate required arguments + if self.arguments: + required = {arg.name for arg in self.arguments if arg.required} + provided = set(arguments or {}) + missing = required - provided + if missing: + raise ValueError(f"Missing required arguments: {missing}") + + try: + # Call function and check if result is a coroutine + result = self.fn(**(arguments or {})) + if inspect.iscoroutine(result): + result = await result + + # Validate messages + if not isinstance(result, (list, tuple)): + result = [result] + + # Convert result to messages + messages = [] + for msg in result: + try: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + msg = message_validator.validate_python(msg) + messages.append(msg) + elif isinstance(msg, str): + messages.append( + UserMessage(content=TextContent(type="text", text=msg)) + ) + else: + msg = json.dumps(pydantic_core.to_jsonable_python(msg)) + messages.append(Message(role="user", content=msg)) + except Exception: + raise ValueError( + f"Could not convert prompt result to message: {msg}" + ) + + return messages + except Exception as e: + raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py new file mode 100644 index 0000000..7ccbdef --- /dev/null +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -0,0 +1,50 @@ +"""Prompt management functionality.""" + +from typing import Any + +from mcp.server.fastmcp.prompts.base import Message, Prompt +from mcp.server.fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class PromptManager: + """Manages FastMCP prompts.""" + + def __init__(self, warn_on_duplicate_prompts: bool = True): + self._prompts: dict[str, Prompt] = {} + self.warn_on_duplicate_prompts = warn_on_duplicate_prompts + + def get_prompt(self, name: str) -> Prompt | None: + """Get prompt by name.""" + return self._prompts.get(name) + + def list_prompts(self) -> list[Prompt]: + """List all registered prompts.""" + return list(self._prompts.values()) + + def add_prompt( + self, + prompt: Prompt, + ) -> Prompt: + """Add a prompt to the manager.""" + + # Check for duplicates + existing = self._prompts.get(prompt.name) + if existing: + if self.warn_on_duplicate_prompts: + logger.warning(f"Prompt already exists: {prompt.name}") + return existing + + self._prompts[prompt.name] = prompt + return prompt + + async def render_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> list[Message]: + """Render a prompt by name with arguments.""" + prompt = self.get_prompt(name) + if not prompt: + raise ValueError(f"Unknown prompt: {name}") + + return await prompt.render(arguments) diff --git a/src/mcp/server/fastmcp/prompts/prompt_manager.py b/src/mcp/server/fastmcp/prompts/prompt_manager.py new file mode 100644 index 0000000..389e896 --- /dev/null +++ b/src/mcp/server/fastmcp/prompts/prompt_manager.py @@ -0,0 +1,33 @@ +"""Prompt management functionality.""" + +from mcp.server.fastmcp.prompts.base import Prompt +from mcp.server.fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class PromptManager: + """Manages FastMCP prompts.""" + + def __init__(self, warn_on_duplicate_prompts: bool = True): + self._prompts: dict[str, Prompt] = {} + self.warn_on_duplicate_prompts = warn_on_duplicate_prompts + + def add_prompt(self, prompt: Prompt) -> Prompt: + """Add a prompt to the manager.""" + logger.debug(f"Adding prompt: {prompt.name}") + existing = self._prompts.get(prompt.name) + if existing: + if self.warn_on_duplicate_prompts: + logger.warning(f"Prompt already exists: {prompt.name}") + return existing + self._prompts[prompt.name] = prompt + return prompt + + def get_prompt(self, name: str) -> Prompt | None: + """Get prompt by name.""" + return self._prompts.get(name) + + def list_prompts(self) -> list[Prompt]: + """List all registered prompts.""" + return list(self._prompts.values()) diff --git a/src/mcp/server/fastmcp/resources/__init__.py b/src/mcp/server/fastmcp/resources/__init__.py new file mode 100644 index 0000000..92deb87 --- /dev/null +++ b/src/mcp/server/fastmcp/resources/__init__.py @@ -0,0 +1,23 @@ +from .base import Resource +from .types import ( + TextResource, + BinaryResource, + FunctionResource, + FileResource, + HttpResource, + DirectoryResource, +) +from .templates import ResourceTemplate +from .resource_manager import ResourceManager + +__all__ = [ + "Resource", + "TextResource", + "BinaryResource", + "FunctionResource", + "FileResource", + "HttpResource", + "DirectoryResource", + "ResourceTemplate", + "ResourceManager", +] diff --git a/src/mcp/server/fastmcp/resources/base.py b/src/mcp/server/fastmcp/resources/base.py new file mode 100644 index 0000000..b2050e7 --- /dev/null +++ b/src/mcp/server/fastmcp/resources/base.py @@ -0,0 +1,48 @@ +"""Base classes and interfaces for FastMCP resources.""" + +import abc +from typing import Annotated + +from pydantic import ( + AnyUrl, + BaseModel, + ConfigDict, + Field, + UrlConstraints, + ValidationInfo, + field_validator, +) + + +class Resource(BaseModel, abc.ABC): + """Base class for all resources.""" + + model_config = ConfigDict(validate_default=True) + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field( + default=..., description="URI of the resource" + ) + name: str | None = Field(description="Name of the resource", default=None) + description: str | None = Field( + description="Description of the resource", default=None + ) + mime_type: str = Field( + default="text/plain", + description="MIME type of the resource content", + pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+$", + ) + + @field_validator("name", mode="before") + @classmethod + def set_default_name(cls, name: str | None, info: ValidationInfo) -> str: + """Set default name from URI if not provided.""" + if name: + return name + if uri := info.data.get("uri"): + return str(uri) + raise ValueError("Either name or uri must be provided") + + @abc.abstractmethod + async def read(self) -> str | bytes: + """Read the resource content.""" + pass diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py new file mode 100644 index 0000000..1f9561e --- /dev/null +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -0,0 +1,95 @@ +"""Resource manager functionality.""" + +from typing import Callable +from collections.abc import Iterable + +from pydantic import AnyUrl + +from mcp.server.fastmcp.resources.base import Resource +from mcp.server.fastmcp.resources.templates import ResourceTemplate +from mcp.server.fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class ResourceManager: + """Manages FastMCP resources.""" + + def __init__(self, warn_on_duplicate_resources: bool = True): + self._resources: dict[str, Resource] = {} + self._templates: dict[str, ResourceTemplate] = {} + self.warn_on_duplicate_resources = warn_on_duplicate_resources + + def add_resource(self, resource: Resource) -> Resource: + """Add a resource to the manager. + + Args: + resource: A Resource instance to add + + Returns: + The added resource. If a resource with the same URI already exists, + returns the existing resource. + """ + logger.debug( + "Adding resource", + extra={ + "uri": resource.uri, + "type": type(resource).__name__, + "name": resource.name, + }, + ) + existing = self._resources.get(str(resource.uri)) + if existing: + if self.warn_on_duplicate_resources: + logger.warning(f"Resource already exists: {resource.uri}") + return existing + self._resources[str(resource.uri)] = resource + return resource + + def add_template( + self, + fn: Callable, + uri_template: str, + name: str | None = None, + description: str | None = None, + mime_type: str | None = None, + ) -> ResourceTemplate: + """Add a template from a function.""" + template = ResourceTemplate.from_function( + fn, + uri_template=uri_template, + name=name, + description=description, + mime_type=mime_type, + ) + self._templates[template.uri_template] = template + return template + + async def get_resource(self, uri: AnyUrl | str) -> Resource | None: + """Get resource by URI, checking concrete resources first, then templates.""" + uri_str = str(uri) + logger.debug("Getting resource", extra={"uri": uri_str}) + + # First check concrete resources + if resource := self._resources.get(uri_str): + return resource + + # Then check templates + for template in self._templates.values(): + if params := template.matches(uri_str): + try: + return await template.create_resource(uri_str, params) + except Exception as e: + raise ValueError(f"Error creating resource from template: {e}") + + raise ValueError(f"Unknown resource: {uri}") + + def list_resources(self) -> list[Resource]: + """List all registered resources.""" + logger.debug("Listing resources", extra={"count": len(self._resources)}) + return list(self._resources.values()) + + def list_templates(self) -> list[ResourceTemplate]: + """List all registered templates.""" + logger.debug("Listing templates", extra={"count": len(self._templates)}) + return list(self._templates.values()) diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py new file mode 100644 index 0000000..40afaf8 --- /dev/null +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -0,0 +1,80 @@ +"""Resource template functionality.""" + +import inspect +import re +from typing import Any, Callable + +from pydantic import BaseModel, Field, TypeAdapter, validate_call + +from mcp.server.fastmcp.resources.types import FunctionResource, Resource + + +class ResourceTemplate(BaseModel): + """A template for dynamically creating resources.""" + + uri_template: str = Field( + description="URI template with parameters (e.g. weather://{city}/current)" + ) + name: str = Field(description="Name of the resource") + description: str | None = Field(description="Description of what the resource does") + mime_type: str = Field( + default="text/plain", description="MIME type of the resource content" + ) + fn: Callable = Field(exclude=True) + parameters: dict = Field(description="JSON schema for function parameters") + + @classmethod + def from_function( + cls, + fn: Callable, + uri_template: str, + name: str | None = None, + description: str | None = None, + mime_type: str | None = None, + ) -> "ResourceTemplate": + """Create a template from a function.""" + func_name = name or fn.__name__ + if func_name == "": + raise ValueError("You must provide a name for lambda functions") + + # Get schema from TypeAdapter - will fail if function isn't properly typed + parameters = TypeAdapter(fn).json_schema() + + # ensure the arguments are properly cast + fn = validate_call(fn) + + return cls( + uri_template=uri_template, + name=func_name, + description=description or fn.__doc__ or "", + mime_type=mime_type or "text/plain", + fn=fn, + parameters=parameters, + ) + + def matches(self, uri: str) -> dict[str, Any] | None: + """Check if URI matches template and extract parameters.""" + # Convert template to regex pattern + pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)") + match = re.match(f"^{pattern}$", uri) + if match: + return match.groupdict() + return None + + async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource: + """Create a resource from the template with the given parameters.""" + try: + # Call function and check if result is a coroutine + result = self.fn(**params) + if inspect.iscoroutine(result): + result = await result + + return FunctionResource( + uri=uri, # type: ignore + name=self.name, + description=self.description, + mime_type=self.mime_type, + fn=lambda: result, # Capture result in closure + ) + except Exception as e: + raise ValueError(f"Error creating resource from template: {e}") diff --git a/src/mcp/server/fastmcp/resources/types.py b/src/mcp/server/fastmcp/resources/types.py new file mode 100644 index 0000000..b1a8088 --- /dev/null +++ b/src/mcp/server/fastmcp/resources/types.py @@ -0,0 +1,181 @@ +"""Concrete resource implementations.""" + +import anyio +import json +from pathlib import Path +from typing import Any, Callable +from collections.abc import Callable + +import httpx +import pydantic.json +import pydantic_core +from pydantic import Field, ValidationInfo + +from mcp.server.fastmcp.resources.base import Resource + + +class TextResource(Resource): + """A resource that reads from a string.""" + + text: str = Field(description="Text content of the resource") + + async def read(self) -> str: + """Read the text content.""" + return self.text + + +class BinaryResource(Resource): + """A resource that reads from bytes.""" + + data: bytes = Field(description="Binary content of the resource") + + async def read(self) -> bytes: + """Read the binary content.""" + return self.data + + +class FunctionResource(Resource): + """A resource that defers data loading by wrapping a function. + + The function is only called when the resource is read, allowing for lazy loading + of potentially expensive data. This is particularly useful when listing resources, + as the function won't be called until the resource is actually accessed. + + The function can return: + - str for text content (default) + - bytes for binary content + - other types will be converted to JSON + """ + + fn: Callable[[], Any] = Field(exclude=True) + + async def read(self) -> str | bytes: + """Read the resource by calling the wrapped function.""" + try: + result = self.fn() + if isinstance(result, Resource): + return await result.read() + if isinstance(result, bytes): + return result + if isinstance(result, str): + return result + try: + return json.dumps(pydantic_core.to_jsonable_python(result)) + except (TypeError, pydantic_core.PydanticSerializationError): + # If JSON serialization fails, try str() + return str(result) + except Exception as e: + raise ValueError(f"Error reading resource {self.uri}: {e}") + + +class FileResource(Resource): + """A resource that reads from a file. + + Set is_binary=True to read file as binary data instead of text. + """ + + path: Path = Field(description="Path to the file") + is_binary: bool = Field( + default=False, + description="Whether to read the file as binary data", + ) + mime_type: str = Field( + default="text/plain", + description="MIME type of the resource content", + ) + + @pydantic.field_validator("path") + @classmethod + def validate_absolute_path(cls, path: Path) -> Path: + """Ensure path is absolute.""" + if not path.is_absolute(): + raise ValueError("Path must be absolute") + return path + + @pydantic.field_validator("is_binary") + @classmethod + def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool: + """Set is_binary based on mime_type if not explicitly set.""" + if is_binary: + return True + mime_type = info.data.get("mime_type", "text/plain") + return not mime_type.startswith("text/") + + async def read(self) -> str | bytes: + """Read the file content.""" + try: + if self.is_binary: + return await anyio.to_thread.run_sync(self.path.read_bytes) + return await anyio.to_thread.run_sync(self.path.read_text) + except Exception as e: + raise ValueError(f"Error reading file {self.path}: {e}") + + +class HttpResource(Resource): + """A resource that reads from an HTTP endpoint.""" + + url: str = Field(description="URL to fetch content from") + mime_type: str = Field( + default="application/json", description="MIME type of the resource content" + ) + + async def read(self) -> str | bytes: + """Read the HTTP content.""" + async with httpx.AsyncClient() as client: + response = await client.get(self.url) + response.raise_for_status() + return response.text + + +class DirectoryResource(Resource): + """A resource that lists files in a directory.""" + + path: Path = Field(description="Path to the directory") + recursive: bool = Field( + default=False, description="Whether to list files recursively" + ) + pattern: str | None = Field( + default=None, description="Optional glob pattern to filter files" + ) + mime_type: str = Field( + default="application/json", description="MIME type of the resource content" + ) + + @pydantic.field_validator("path") + @classmethod + def validate_absolute_path(cls, path: Path) -> Path: + """Ensure path is absolute.""" + if not path.is_absolute(): + raise ValueError("Path must be absolute") + return path + + def list_files(self) -> list[Path]: + """List files in the directory.""" + if not self.path.exists(): + raise FileNotFoundError(f"Directory not found: {self.path}") + if not self.path.is_dir(): + raise NotADirectoryError(f"Not a directory: {self.path}") + + try: + if self.pattern: + return ( + list(self.path.glob(self.pattern)) + if not self.recursive + else list(self.path.rglob(self.pattern)) + ) + return ( + list(self.path.glob("*")) + if not self.recursive + else list(self.path.rglob("*")) + ) + except Exception as e: + raise ValueError(f"Error listing directory {self.path}: {e}") + + async def read(self) -> str: # Always returns JSON string + """Read the directory listing.""" + try: + files = await anyio.to_thread.run_sync(self.list_files) + file_list = [str(f.relative_to(self.path)) for f in files if f.is_file()] + return json.dumps({"files": file_list}, indent=2) + except Exception as e: + raise ValueError(f"Error reading directory {self.path}: {e}") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py new file mode 100644 index 0000000..b00627e --- /dev/null +++ b/src/mcp/server/fastmcp/server.py @@ -0,0 +1,668 @@ +"""FastMCP - A more ergonomic interface for MCP servers.""" + +import anyio +import functools +import inspect +import json +import re +from itertools import chain +from typing import Any, Callable, Literal, Sequence +from collections.abc import Iterable + +import pydantic_core +from pydantic import Field +import uvicorn +from mcp.server.lowlevel import Server as MCPServer +from mcp.server.sse import SseServerTransport +from mcp.server.stdio import stdio_server +from mcp.shared.context import RequestContext +from mcp.types import ( + EmbeddedResource, + GetPromptResult, + ImageContent, + TextContent, +) +from mcp.types import ( + Prompt as MCPPrompt, + PromptArgument as MCPPromptArgument, +) +from mcp.types import ( + Resource as MCPResource, +) +from mcp.types import ( + ResourceTemplate as MCPResourceTemplate, +) +from mcp.types import ( + Tool as MCPTool, +) +from pydantic import BaseModel +from pydantic.networks import AnyUrl +from pydantic_settings import BaseSettings, SettingsConfigDict + +from mcp.server.fastmcp.exceptions import ResourceError +from mcp.server.fastmcp.prompts import Prompt, PromptManager +from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager +from mcp.server.fastmcp.tools import ToolManager +from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger +from mcp.server.fastmcp.utilities.types import Image + +logger = get_logger(__name__) + + +class Settings(BaseSettings): + """FastMCP server settings. + + All settings can be configured via environment variables with the prefix FASTMCP_. + For example, FASTMCP_DEBUG=true will set debug=True. + """ + + model_config = SettingsConfigDict( + env_prefix="FASTMCP_", + env_file=".env", + extra="ignore", + ) + + # Server settings + debug: bool = False + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + + # HTTP settings + host: str = "0.0.0.0" + port: int = 8000 + + # resource settings + warn_on_duplicate_resources: bool = True + + # tool settings + warn_on_duplicate_tools: bool = True + + # prompt settings + warn_on_duplicate_prompts: bool = True + + dependencies: list[str] = Field( + default_factory=list, + description="List of dependencies to install in the server environment", + ) + + +class FastMCP: + def __init__(self, name: str | None = None, **settings: Any): + self.settings = Settings(**settings) + self._mcp_server = MCPServer(name=name or "FastMCP") + self._tool_manager = ToolManager( + warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools + ) + self._resource_manager = ResourceManager( + warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources + ) + self._prompt_manager = PromptManager( + warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts + ) + self.dependencies = self.settings.dependencies + + # Set up MCP protocol handlers + self._setup_handlers() + + # Configure logging + configure_logging(self.settings.log_level) + + @property + def name(self) -> str: + return self._mcp_server.name + + def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None: + """Run the FastMCP server. Note this is a synchronous function. + + Args: + transport: Transport protocol to use ("stdio" or "sse") + """ + TRANSPORTS = Literal["stdio", "sse"] + if transport not in TRANSPORTS.__args__: # type: ignore + raise ValueError(f"Unknown transport: {transport}") + + if transport == "stdio": + anyio.run(self.run_stdio_async) + else: # transport == "sse" + anyio.run(self.run_sse_async) + + def _setup_handlers(self) -> None: + """Set up core MCP protocol handlers.""" + self._mcp_server.list_tools()(self.list_tools) + self._mcp_server.call_tool()(self.call_tool) + self._mcp_server.list_resources()(self.list_resources) + self._mcp_server.read_resource()(self.read_resource) + self._mcp_server.list_prompts()(self.list_prompts) + self._mcp_server.get_prompt()(self.get_prompt) + # TODO: This has not been added to MCP yet, see https://github.com/jlowin/fastmcp/issues/10 + # self._mcp_server.list_resource_templates()(self.list_resource_templates) + + async def list_tools(self) -> list[MCPTool]: + """List all available tools.""" + tools = self._tool_manager.list_tools() + return [ + MCPTool( + name=info.name, + description=info.description, + inputSchema=info.parameters, + ) + for info in tools + ] + + def get_context(self) -> "Context": + """ + Returns a Context object. Note that the context will only be valid + during a request; outside a request, most methods will error. + """ + try: + request_context = self._mcp_server.request_context + except LookupError: + request_context = None + return Context(request_context=request_context, fastmcp=self) + + async def call_tool( + self, name: str, arguments: dict + ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + """Call a tool by name with arguments.""" + context = self.get_context() + result = await self._tool_manager.call_tool(name, arguments, context=context) + converted_result = _convert_to_content(result) + return converted_result + + async def list_resources(self) -> list[MCPResource]: + """List all available resources.""" + + resources = self._resource_manager.list_resources() + return [ + MCPResource( + uri=resource.uri, + name=resource.name or "", + description=resource.description, + mimeType=resource.mime_type, + ) + for resource in resources + ] + + async def list_resource_templates(self) -> list[MCPResourceTemplate]: + templates = self._resource_manager.list_templates() + return [ + MCPResourceTemplate( + uriTemplate=template.uri_template, + name=template.name, + description=template.description, + ) + for template in templates + ] + + async def read_resource(self, uri: AnyUrl | str) -> str | bytes: + """Read a resource by URI.""" + resource = await self._resource_manager.get_resource(uri) + if not resource: + raise ResourceError(f"Unknown resource: {uri}") + + try: + return await resource.read() + except Exception as e: + logger.error(f"Error reading resource {uri}: {e}") + raise ResourceError(str(e)) + + def add_tool( + self, + fn: Callable, + name: str | None = None, + description: str | None = None, + ) -> None: + """Add a tool to the server. + + The tool function can optionally request a Context object by adding a parameter + with the Context type annotation. See the @tool decorator for examples. + + Args: + fn: The function to register as a tool + name: Optional name for the tool (defaults to function name) + description: Optional description of what the tool does + """ + self._tool_manager.add_tool(fn, name=name, description=description) + + def tool(self, name: str | None = None, description: str | None = None) -> Callable: + """Decorator to register a tool. + + Tools can optionally request a Context object by adding a parameter with the Context type annotation. + The context provides access to MCP capabilities like logging, progress reporting, and resource access. + + Args: + name: Optional name for the tool (defaults to function name) + description: Optional description of what the tool does + + Example: + @server.tool() + def my_tool(x: int) -> str: + return str(x) + + @server.tool() + def tool_with_context(x: int, ctx: Context) -> str: + ctx.info(f"Processing {x}") + return str(x) + + @server.tool() + async def async_tool(x: int, context: Context) -> str: + await context.report_progress(50, 100) + return str(x) + """ + # Check if user passed function directly instead of calling decorator + if callable(name): + raise TypeError( + "The @tool decorator was used incorrectly. " + "Did you forget to call it? Use @tool() instead of @tool" + ) + + def decorator(fn: Callable) -> Callable: + self.add_tool(fn, name=name, description=description) + return fn + + return decorator + + def add_resource(self, resource: Resource) -> None: + """Add a resource to the server. + + Args: + resource: A Resource instance to add + """ + self._resource_manager.add_resource(resource) + + def resource( + self, + uri: str, + *, + name: str | None = None, + description: str | None = None, + mime_type: str | None = None, + ) -> Callable: + """Decorator to register a function as a resource. + + The function will be called when the resource is read to generate its content. + The function can return: + - str for text content + - bytes for binary content + - other types will be converted to JSON + + If the URI contains parameters (e.g. "resource://{param}") or the function + has parameters, it will be registered as a template resource. + + Args: + uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}") + name: Optional name for the resource + description: Optional description of the resource + mime_type: Optional MIME type for the resource + + Example: + @server.resource("resource://my-resource") + def get_data() -> str: + return "Hello, world!" + + @server.resource("resource://{city}/weather") + def get_weather(city: str) -> str: + return f"Weather for {city}" + """ + # Check if user passed function directly instead of calling decorator + if callable(uri): + raise TypeError( + "The @resource decorator was used incorrectly. " + "Did you forget to call it? Use @resource('uri') instead of @resource" + ) + + def decorator(fn: Callable) -> Callable: + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return fn(*args, **kwargs) + + # Check if this should be a template + has_uri_params = "{" in uri and "}" in uri + has_func_params = bool(inspect.signature(fn).parameters) + + if has_uri_params or has_func_params: + # Validate that URI params match function params + uri_params = set(re.findall(r"{(\w+)}", uri)) + func_params = set(inspect.signature(fn).parameters.keys()) + + if uri_params != func_params: + raise ValueError( + f"Mismatch between URI parameters {uri_params} " + f"and function parameters {func_params}" + ) + + # Register as template + self._resource_manager.add_template( + wrapper, + uri_template=uri, + name=name, + description=description, + mime_type=mime_type or "text/plain", + ) + else: + # Register as regular resource + resource = FunctionResource( + uri=AnyUrl(uri), + name=name, + description=description, + mime_type=mime_type or "text/plain", + fn=wrapper, + ) + self.add_resource(resource) + return wrapper + + return decorator + + def add_prompt(self, prompt: Prompt) -> None: + """Add a prompt to the server. + + Args: + prompt: A Prompt instance to add + """ + self._prompt_manager.add_prompt(prompt) + + def prompt( + self, name: str | None = None, description: str | None = None + ) -> Callable: + """Decorator to register a prompt. + + Args: + name: Optional name for the prompt (defaults to function name) + description: Optional description of what the prompt does + + Example: + @server.prompt() + def analyze_table(table_name: str) -> list[Message]: + schema = read_table_schema(table_name) + return [ + { + "role": "user", + "content": f"Analyze this schema:\n{schema}" + } + ] + + @server.prompt() + async def analyze_file(path: str) -> list[Message]: + content = await read_file(path) + return [ + { + "role": "user", + "content": { + "type": "resource", + "resource": { + "uri": f"file://{path}", + "text": content + } + } + } + ] + """ + # Check if user passed function directly instead of calling decorator + if callable(name): + raise TypeError( + "The @prompt decorator was used incorrectly. " + "Did you forget to call it? Use @prompt() instead of @prompt" + ) + + def decorator(func: Callable) -> Callable: + prompt = Prompt.from_function(func, name=name, description=description) + self.add_prompt(prompt) + return func + + return decorator + + async def run_stdio_async(self) -> None: + """Run the server using stdio transport.""" + async with stdio_server() as (read_stream, write_stream): + await self._mcp_server.run( + read_stream, + write_stream, + self._mcp_server.create_initialization_options(), + ) + + async def run_sse_async(self) -> None: + """Run the server using SSE transport.""" + from starlette.applications import Starlette + from starlette.routing import Route + + sse = SseServerTransport("/messages") + + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await self._mcp_server.run( + streams[0], + streams[1], + self._mcp_server.create_initialization_options(), + ) + + async def handle_messages(request): + await sse.handle_post_message(request.scope, request.receive, request._send) + + starlette_app = Starlette( + debug=self.settings.debug, + routes=[ + Route("/sse", endpoint=handle_sse), + Route("/messages", endpoint=handle_messages, methods=["POST"]), + ], + ) + + config = uvicorn.Config( + starlette_app, + host=self.settings.host, + port=self.settings.port, + log_level=self.settings.log_level.lower(), + ) + server = uvicorn.Server(config) + await server.serve() + + async def list_prompts(self) -> list[MCPPrompt]: + """List all available prompts.""" + prompts = self._prompt_manager.list_prompts() + return [ + MCPPrompt( + name=prompt.name, + description=prompt.description, + arguments=[ + MCPPromptArgument( + name=arg.name, + description=arg.description, + required=arg.required, + ) + for arg in (prompt.arguments or []) + ], + ) + for prompt in prompts + ] + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + """Get a prompt by name with arguments.""" + try: + messages = await self._prompt_manager.render_prompt(name, arguments) + + return GetPromptResult(messages=pydantic_core.to_jsonable_python(messages)) + except Exception as e: + logger.error(f"Error getting prompt {name}: {e}") + raise ValueError(str(e)) + + +def _convert_to_content( + result: Any, +) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + """Convert a result to a sequence of content objects.""" + if result is None: + return [] + + if isinstance(result, (TextContent, ImageContent, EmbeddedResource)): + return [result] + + if isinstance(result, Image): + return [result.to_image_content()] + + if isinstance(result, (list, tuple)): + return list(chain.from_iterable(_convert_to_content(item) for item in result)) + + if not isinstance(result, str): + try: + result = json.dumps(pydantic_core.to_jsonable_python(result)) + except Exception: + result = str(result) + + return [TextContent(type="text", text=result)] + + +class Context(BaseModel): + """Context object providing access to MCP capabilities. + + This provides a cleaner interface to MCP's RequestContext functionality. + It gets injected into tool and resource functions that request it via type hints. + + To use context in a tool function, add a parameter with the Context type annotation: + + ```python + @server.tool() + def my_tool(x: int, ctx: Context) -> str: + # Log messages to the client + ctx.info(f"Processing {x}") + ctx.debug("Debug info") + ctx.warning("Warning message") + ctx.error("Error message") + + # Report progress + ctx.report_progress(50, 100) + + # Access resources + data = ctx.read_resource("resource://data") + + # Get request info + request_id = ctx.request_id + client_id = ctx.client_id + + return str(x) + ``` + + The context parameter name can be anything as long as it's annotated with Context. + The context is optional - tools that don't need it can omit the parameter. + """ + + _request_context: RequestContext | None + _fastmcp: FastMCP | None + + def __init__( + self, + *, + request_context: RequestContext | None = None, + fastmcp: FastMCP | None = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self._request_context = request_context + self._fastmcp = fastmcp + + @property + def fastmcp(self) -> FastMCP: + """Access to the FastMCP server.""" + if self._fastmcp is None: + raise ValueError("Context is not available outside of a request") + return self._fastmcp + + @property + def request_context(self) -> RequestContext: + """Access to the underlying request context.""" + if self._request_context is None: + raise ValueError("Context is not available outside of a request") + return self._request_context + + async def report_progress( + self, progress: float, total: float | None = None + ) -> None: + """Report progress for the current operation. + + Args: + progress: Current progress value e.g. 24 + total: Optional total value e.g. 100 + """ + + progress_token = ( + self.request_context.meta.progressToken + if self.request_context.meta + else None + ) + + if not progress_token: + return + + await self.request_context.session.send_progress_notification( + progress_token=progress_token, progress=progress, total=total + ) + + async def read_resource(self, uri: str | AnyUrl) -> str | bytes: + """Read a resource by URI. + + Args: + uri: Resource URI to read + + Returns: + The resource content as either text or bytes + """ + assert ( + self._fastmcp is not None + ), "Context is not available outside of a request" + return await self._fastmcp.read_resource(uri) + + def log( + self, + level: Literal["debug", "info", "warning", "error"], + message: str, + *, + logger_name: str | None = None, + ) -> None: + """Send a log message to the client. + + Args: + level: Log level (debug, info, warning, error) + message: Log message + logger_name: Optional logger name + **extra: Additional structured data to include + """ + self.request_context.session.send_log_message( + level=level, data=message, logger=logger_name + ) + + @property + def client_id(self) -> str | None: + """Get the client ID if available.""" + return ( + getattr(self.request_context.meta, "client_id", None) + if self.request_context.meta + else None + ) + + @property + def request_id(self) -> str: + """Get the unique ID for this request.""" + return str(self.request_context.request_id) + + @property + def session(self): + """Access to the underlying session for advanced usage.""" + return self.request_context.session + + # Convenience methods for common log levels + def debug(self, message: str, **extra: Any) -> None: + """Send a debug log message.""" + self.log("debug", message, **extra) + + def info(self, message: str, **extra: Any) -> None: + """Send an info log message.""" + self.log("info", message, **extra) + + def warning(self, message: str, **extra: Any) -> None: + """Send a warning log message.""" + self.log("warning", message, **extra) + + def error(self, message: str, **extra: Any) -> None: + """Send an error log message.""" + self.log("error", message, **extra) diff --git a/src/mcp/server/fastmcp/tools/__init__.py b/src/mcp/server/fastmcp/tools/__init__.py new file mode 100644 index 0000000..ae9c656 --- /dev/null +++ b/src/mcp/server/fastmcp/tools/__init__.py @@ -0,0 +1,4 @@ +from .base import Tool +from .tool_manager import ToolManager + +__all__ = ["Tool", "ToolManager"] diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py new file mode 100644 index 0000000..8f2ea48 --- /dev/null +++ b/src/mcp/server/fastmcp/tools/base.py @@ -0,0 +1,82 @@ +import mcp.server.fastmcp +from mcp.server.fastmcp.exceptions import ToolError +from mcp.server.fastmcp.utilities.func_metadata import func_metadata, FuncMetadata +from pydantic import BaseModel, Field + + +import inspect +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context + + +class Tool(BaseModel): + """Internal tool registration info.""" + + fn: Callable = Field(exclude=True) + name: str = Field(description="Name of the tool") + description: str = Field(description="Description of what the tool does") + parameters: dict = Field(description="JSON schema for tool parameters") + fn_metadata: FuncMetadata = Field( + description="Metadata about the function including a pydantic model for tool arguments" + ) + is_async: bool = Field(description="Whether the tool is async") + context_kwarg: str | None = Field( + None, description="Name of the kwarg that should receive context" + ) + + @classmethod + def from_function( + cls, + fn: Callable, + name: str | None = None, + description: str | None = None, + context_kwarg: str | None = None, + ) -> "Tool": + """Create a Tool from a function.""" + func_name = name or fn.__name__ + + if func_name == "": + raise ValueError("You must provide a name for lambda functions") + + func_doc = description or fn.__doc__ or "" + is_async = inspect.iscoroutinefunction(fn) + + # Find context parameter if it exists + if context_kwarg is None: + sig = inspect.signature(fn) + for param_name, param in sig.parameters.items(): + if param.annotation is mcp.server.fastmcp.Context: + context_kwarg = param_name + break + + func_arg_metadata = func_metadata( + fn, + skip_names=[context_kwarg] if context_kwarg is not None else [], + ) + parameters = func_arg_metadata.arg_model.model_json_schema() + + return cls( + fn=fn, + name=func_name, + description=func_doc, + parameters=parameters, + fn_metadata=func_arg_metadata, + is_async=is_async, + context_kwarg=context_kwarg, + ) + + async def run(self, arguments: dict, context: "Context | None" = None) -> Any: + """Run the tool with arguments.""" + try: + return await self.fn_metadata.call_fn_with_arg_validation( + self.fn, + self.is_async, + arguments, + {self.context_kwarg: context} + if self.context_kwarg is not None + else None, + ) + except Exception as e: + raise ToolError(f"Error executing tool {self.name}: {e}") from e diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py new file mode 100644 index 0000000..52b45d1 --- /dev/null +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -0,0 +1,54 @@ +from mcp.server.fastmcp.exceptions import ToolError +from mcp.server.fastmcp.tools.base import Tool + +from typing import Any, Callable, TYPE_CHECKING +from collections.abc import Callable + +from mcp.server.fastmcp.utilities.logging import get_logger + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context + +logger = get_logger(__name__) + + +class ToolManager: + """Manages FastMCP tools.""" + + def __init__(self, warn_on_duplicate_tools: bool = True): + self._tools: dict[str, Tool] = {} + self.warn_on_duplicate_tools = warn_on_duplicate_tools + + def get_tool(self, name: str) -> Tool | None: + """Get tool by name.""" + return self._tools.get(name) + + def list_tools(self) -> list[Tool]: + """List all registered tools.""" + return list(self._tools.values()) + + def add_tool( + self, + fn: Callable, + name: str | None = None, + description: str | None = None, + ) -> Tool: + """Add a tool to the server.""" + tool = Tool.from_function(fn, name=name, description=description) + existing = self._tools.get(tool.name) + if existing: + if self.warn_on_duplicate_tools: + logger.warning(f"Tool already exists: {tool.name}") + return existing + self._tools[tool.name] = tool + return tool + + async def call_tool( + self, name: str, arguments: dict, context: "Context | None" = None + ) -> Any: + """Call a tool by name with arguments.""" + tool = self.get_tool(name) + if not tool: + raise ToolError(f"Unknown tool: {name}") + + return await tool.run(arguments, context=context) diff --git a/src/mcp/server/fastmcp/utilities/__init__.py b/src/mcp/server/fastmcp/utilities/__init__.py new file mode 100644 index 0000000..be448f9 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/__init__.py @@ -0,0 +1 @@ +"""FastMCP utility modules.""" diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py new file mode 100644 index 0000000..b1f1385 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -0,0 +1,210 @@ +import inspect +from collections.abc import Callable, Sequence, Awaitable +from typing import ( + Annotated, + Any, + ForwardRef, +) +from pydantic import Field +from mcp.server.fastmcp.exceptions import InvalidSignature +from pydantic._internal._typing_extra import eval_type_backport +import json +from pydantic import BaseModel +from pydantic.fields import FieldInfo +from pydantic import ConfigDict, create_model +from pydantic import WithJsonSchema +from pydantic_core import PydanticUndefined +from mcp.server.fastmcp.utilities.logging import get_logger + + +logger = get_logger(__name__) + + +class ArgModelBase(BaseModel): + """A model representing the arguments to a function.""" + + def model_dump_one_level(self) -> dict[str, Any]: + """Return a dict of the model's fields, one level deep. + + That is, sub-models etc are not dumped - they are kept as pydantic models. + """ + kwargs: dict[str, Any] = {} + for field_name in self.model_fields.keys(): + kwargs[field_name] = getattr(self, field_name) + return kwargs + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + +class FuncMetadata(BaseModel): + arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)] + # We can add things in the future like + # - Maybe some args are excluded from attempting to parse from JSON + # - Maybe some args are special (like context) for dependency injection + + async def call_fn_with_arg_validation( + self, + fn: Callable[..., Any] | Awaitable[Any], + fn_is_async: bool, + arguments_to_validate: dict[str, Any], + arguments_to_pass_directly: dict[str, Any] | None, + ) -> Any: + """Call the given function with arguments validated and injected. + + Arguments are first attempted to be parsed from JSON, then validated against + the argument model, before being passed to the function. + """ + arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) + arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) + arguments_parsed_dict = arguments_parsed_model.model_dump_one_level() + + arguments_parsed_dict |= arguments_to_pass_directly or {} + + if fn_is_async: + if isinstance(fn, Awaitable): + return await fn + return await fn(**arguments_parsed_dict) + if isinstance(fn, Callable): + return fn(**arguments_parsed_dict) + raise TypeError("fn must be either Callable or Awaitable") + + def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: + """Pre-parse data from JSON. + + Return a dict with same keys as input but with values parsed from JSON + if appropriate. + + This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside + a string rather than an actual list. Claude desktop is prone to this - in fact + it seems incapable of NOT doing this. For sub-models, it tends to pass + dicts (JSON objects) as JSON strings, which can be pre-parsed here. + """ + new_data = data.copy() # Shallow copy + for field_name, field_info in self.arg_model.model_fields.items(): + if field_name not in data.keys(): + continue + if isinstance(data[field_name], str): + try: + pre_parsed = json.loads(data[field_name]) + except json.JSONDecodeError: + continue # Not JSON - skip + if isinstance(pre_parsed, str): + # This is likely that the raw value is e.g. `"hello"` which we + # Should really be parsed as '"hello"' in Python - but if we parse + # it as JSON it'll turn into just 'hello'. So we skip it. + continue + new_data[field_name] = pre_parsed + assert new_data.keys() == data.keys() + return new_data + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + +def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata: + """Given a function, return metadata including a pydantic model representing its signature. + + The use case for this is + ``` + meta = func_to_pyd(func) + validated_args = meta.arg_model.model_validate(some_raw_data_dict) + return func(**validated_args.model_dump_one_level()) + ``` + + **critically** it also provides pre-parse helper to attempt to parse things from JSON. + + Args: + func: The function to convert to a pydantic model + skip_names: A list of parameter names to skip. These will not be included in + the model. + Returns: + A pydantic model representing the function's signature. + """ + sig = _get_typed_signature(func) + params = sig.parameters + dynamic_pydantic_model_params: dict[str, Any] = {} + globalns = getattr(func, "__globals__", {}) + for param in params.values(): + if param.name.startswith("_"): + raise InvalidSignature( + f"Parameter {param.name} of {func.__name__} may not start with an underscore" + ) + if param.name in skip_names: + continue + annotation = param.annotation + + # `x: None` / `x: None = None` + if annotation is None: + annotation = Annotated[ + None, + Field( + default=param.default + if param.default is not inspect.Parameter.empty + else PydanticUndefined + ), + ] + + # Untyped field + if annotation is inspect.Parameter.empty: + annotation = Annotated[ + Any, + Field(), + # 🤷 + WithJsonSchema({"title": param.name, "type": "string"}), + ] + + field_info = FieldInfo.from_annotated_attribute( + _get_typed_annotation(annotation, globalns), + param.default + if param.default is not inspect.Parameter.empty + else PydanticUndefined, + ) + dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info) + continue + + arguments_model = create_model( + f"{func.__name__}Arguments", + **dynamic_pydantic_model_params, + __base__=ArgModelBase, + ) + resp = FuncMetadata(arg_model=arguments_model) + return resp + + +def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: + def try_eval_type(value, globalns, localns): + try: + return eval_type_backport(value, globalns, localns), True + except NameError: + return value, False + + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation, status = try_eval_type(annotation, globalns, globalns) + + # This check and raise could perhaps be skipped, and we (FastMCP) just call + # model_rebuild right before using it 🤷 + if status is False: + raise InvalidSignature(f"Unable to evaluate type annotation {annotation}") + + return annotation + + +def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """Get function signature while evaluating forward references""" + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=_get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature diff --git a/src/mcp/server/fastmcp/utilities/logging.py b/src/mcp/server/fastmcp/utilities/logging.py new file mode 100644 index 0000000..60738f8 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/logging.py @@ -0,0 +1,41 @@ +"""Logging utilities for FastMCP.""" + +import logging +from typing import Literal + +def get_logger(name: str) -> logging.Logger: + """Get a logger nested under MCPnamespace. + + Args: + name: the name of the logger, which will be prefixed with 'FastMCP.' + + Returns: + a configured logger instance + """ + return logging.getLogger(name) + + +def configure_logging( + level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", +) -> None: + """Configure logging for MCP. + + Args: + level: the log level to use + """ + handlers = [] + try: + from rich.console import Console + from rich.logging import RichHandler + handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True)) + except ImportError: + pass + + if not handlers: + handlers.append(logging.StreamHandler()) + + logging.basicConfig( + level=level, + format="%(message)s", + handlers=handlers, + ) diff --git a/src/mcp/server/fastmcp/utilities/types.py b/src/mcp/server/fastmcp/utilities/types.py new file mode 100644 index 0000000..ccaa3d6 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/types.py @@ -0,0 +1,54 @@ +"""Common types used across FastMCP.""" + +import base64 +from pathlib import Path + +from mcp.types import ImageContent + + +class Image: + """Helper class for returning images from tools.""" + + def __init__( + self, + path: str | Path | None = None, + data: bytes | None = None, + format: str | None = None, + ): + if path is None and data is None: + raise ValueError("Either path or data must be provided") + if path is not None and data is not None: + raise ValueError("Only one of path or data can be provided") + + self.path = Path(path) if path else None + self.data = data + self._format = format + self._mime_type = self._get_mime_type() + + def _get_mime_type(self) -> str: + """Get MIME type from format or guess from file extension.""" + if self._format: + return f"image/{self._format.lower()}" + + if self.path: + suffix = self.path.suffix.lower() + return { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", + }.get(suffix, "application/octet-stream") + return "image/png" # default for raw binary data + + def to_image_content(self) -> ImageContent: + """Convert to MCP ImageContent.""" + if self.path: + with open(self.path, "rb") as f: + data = base64.b64encode(f.read()).decode() + elif self.data is not None: + data = base64.b64encode(self.data).decode() + else: + raise ValueError("No image data available") + + return ImageContent(type="image", data=data, mimeType=self._mime_type) diff --git a/src/mcp/server/lowlevel/__init__.py b/src/mcp/server/lowlevel/__init__.py new file mode 100644 index 0000000..a6dff43 --- /dev/null +++ b/src/mcp/server/lowlevel/__init__.py @@ -0,0 +1,3 @@ +from .server import Server, NotificationOptions + +__all__ = ["Server", "NotificationOptions"] diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py new file mode 100644 index 0000000..a0dd033 --- /dev/null +++ b/src/mcp/server/lowlevel/server.py @@ -0,0 +1,500 @@ +""" +MCP Server Module + +This module provides a framework for creating an MCP (Model Context Protocol) server. +It allows you to easily define and handle various types of requests and notifications +in an asynchronous manner. + +Usage: +1. Create a Server instance: + server = Server("your_server_name") + +2. Define request handlers using decorators: + @server.list_prompts() + async def handle_list_prompts() -> list[types.Prompt]: + # Implementation + + @server.get_prompt() + async def handle_get_prompt( + name: str, arguments: dict[str, str] | None + ) -> types.GetPromptResult: + # Implementation + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + # Implementation + + @server.call_tool() + async def handle_call_tool( + name: str, arguments: dict | None + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + # Implementation + + @server.list_resource_templates() + async def handle_list_resource_templates() -> list[types.ResourceTemplate]: + # Implementation + +3. Define notification handlers if needed: + @server.progress_notification() + async def handle_progress( + progress_token: str | int, progress: float, total: float | None + ) -> None: + # Implementation + +4. Run the server: + async def main(): + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="your_server_name", + server_version="your_version", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + asyncio.run(main()) + +The Server class provides methods to register handlers for various MCP requests and +notifications. It automatically manages the request context and handles incoming +messages from the client. +""" + +import contextvars +import logging +import warnings +from collections.abc import Awaitable, Callable +from typing import Any, Sequence + +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import AnyUrl + +import mcp.types as types +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.server.stdio import stdio_server as stdio_server +from mcp.shared.context import RequestContext +from mcp.shared.exceptions import McpError +from mcp.shared.session import RequestResponder + +logger = logging.getLogger(__name__) + +request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = ( + contextvars.ContextVar("request_ctx") +) + + +class NotificationOptions: + def __init__( + self, + prompts_changed: bool = False, + resources_changed: bool = False, + tools_changed: bool = False, + ): + self.prompts_changed = prompts_changed + self.resources_changed = resources_changed + self.tools_changed = tools_changed + + +class Server: + def __init__(self, name: str): + self.name = name + self.request_handlers: dict[ + type, Callable[..., Awaitable[types.ServerResult]] + ] = { + types.PingRequest: _ping_handler, + } + self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} + self.notification_options = NotificationOptions() + logger.debug(f"Initializing server '{name}'") + + def create_initialization_options( + self, + notification_options: NotificationOptions | None = None, + experimental_capabilities: dict[str, dict[str, Any]] | None = None, + ) -> InitializationOptions: + """Create initialization options from this server instance.""" + + def pkg_version(package: str) -> str: + try: + from importlib.metadata import version + + v = version(package) + if v is not None: + return v + except Exception: + pass + + return "unknown" + + return InitializationOptions( + server_name=self.name, + server_version=pkg_version("mcp"), + capabilities=self.get_capabilities( + notification_options or NotificationOptions(), + experimental_capabilities or {}, + ), + ) + + def get_capabilities( + self, + notification_options: NotificationOptions, + experimental_capabilities: dict[str, dict[str, Any]], + ) -> types.ServerCapabilities: + """Convert existing handlers to a ServerCapabilities object.""" + prompts_capability = None + resources_capability = None + tools_capability = None + logging_capability = None + + # Set prompt capabilities if handler exists + if types.ListPromptsRequest in self.request_handlers: + prompts_capability = types.PromptsCapability( + listChanged=notification_options.prompts_changed + ) + + # Set resource capabilities if handler exists + if types.ListResourcesRequest in self.request_handlers: + resources_capability = types.ResourcesCapability( + subscribe=False, listChanged=notification_options.resources_changed + ) + + # Set tool capabilities if handler exists + if types.ListToolsRequest in self.request_handlers: + tools_capability = types.ToolsCapability( + listChanged=notification_options.tools_changed + ) + + # Set logging capabilities if handler exists + if types.SetLevelRequest in self.request_handlers: + logging_capability = types.LoggingCapability() + + return types.ServerCapabilities( + prompts=prompts_capability, + resources=resources_capability, + tools=tools_capability, + logging=logging_capability, + experimental=experimental_capabilities, + ) + + @property + def request_context(self) -> RequestContext[ServerSession]: + """If called outside of a request context, this will raise a LookupError.""" + return request_ctx.get() + + def list_prompts(self): + def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]): + logger.debug("Registering handler for PromptListRequest") + + async def handler(_: Any): + prompts = await func() + return types.ServerResult(types.ListPromptsResult(prompts=prompts)) + + self.request_handlers[types.ListPromptsRequest] = handler + return func + + return decorator + + def get_prompt(self): + def decorator( + func: Callable[ + [str, dict[str, str] | None], Awaitable[types.GetPromptResult] + ], + ): + logger.debug("Registering handler for GetPromptRequest") + + async def handler(req: types.GetPromptRequest): + prompt_get = await func(req.params.name, req.params.arguments) + return types.ServerResult(prompt_get) + + self.request_handlers[types.GetPromptRequest] = handler + return func + + return decorator + + def list_resources(self): + def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): + logger.debug("Registering handler for ListResourcesRequest") + + async def handler(_: Any): + resources = await func() + return types.ServerResult( + types.ListResourcesResult(resources=resources) + ) + + self.request_handlers[types.ListResourcesRequest] = handler + return func + + return decorator + + def list_resource_templates(self): + def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): + logger.debug("Registering handler for ListResourceTemplatesRequest") + + async def handler(_: Any): + templates = await func() + return types.ServerResult( + types.ListResourceTemplatesResult(resourceTemplates=templates) + ) + + self.request_handlers[types.ListResourceTemplatesRequest] = handler + return func + + return decorator + + def read_resource(self): + def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]): + logger.debug("Registering handler for ReadResourceRequest") + + async def handler(req: types.ReadResourceRequest): + result = await func(req.params.uri) + match result: + case str(s): + content = types.TextResourceContents( + uri=req.params.uri, + text=s, + mimeType="text/plain", + ) + case bytes(b): + import base64 + + content = types.BlobResourceContents( + uri=req.params.uri, + blob=base64.urlsafe_b64encode(b).decode(), + mimeType="application/octet-stream", + ) + + return types.ServerResult( + types.ReadResourceResult( + contents=[content], + ) + ) + + self.request_handlers[types.ReadResourceRequest] = handler + return func + + return decorator + + def set_logging_level(self): + def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): + logger.debug("Registering handler for SetLevelRequest") + + async def handler(req: types.SetLevelRequest): + await func(req.params.level) + return types.ServerResult(types.EmptyResult()) + + self.request_handlers[types.SetLevelRequest] = handler + return func + + return decorator + + def subscribe_resource(self): + def decorator(func: Callable[[AnyUrl], Awaitable[None]]): + logger.debug("Registering handler for SubscribeRequest") + + async def handler(req: types.SubscribeRequest): + await func(req.params.uri) + return types.ServerResult(types.EmptyResult()) + + self.request_handlers[types.SubscribeRequest] = handler + return func + + return decorator + + def unsubscribe_resource(self): + def decorator(func: Callable[[AnyUrl], Awaitable[None]]): + logger.debug("Registering handler for UnsubscribeRequest") + + async def handler(req: types.UnsubscribeRequest): + await func(req.params.uri) + return types.ServerResult(types.EmptyResult()) + + self.request_handlers[types.UnsubscribeRequest] = handler + return func + + return decorator + + def list_tools(self): + def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): + logger.debug("Registering handler for ListToolsRequest") + + async def handler(_: Any): + tools = await func() + return types.ServerResult(types.ListToolsResult(tools=tools)) + + self.request_handlers[types.ListToolsRequest] = handler + return func + + return decorator + + def call_tool(self): + def decorator( + func: Callable[ + ..., + Awaitable[ + Sequence[ + types.TextContent | types.ImageContent | types.EmbeddedResource + ] + ], + ], + ): + logger.debug("Registering handler for CallToolRequest") + + async def handler(req: types.CallToolRequest): + try: + results = await func(req.params.name, (req.params.arguments or {})) + return types.ServerResult( + types.CallToolResult(content=list(results), isError=False) + ) + except Exception as e: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text=str(e))], + isError=True, + ) + ) + + self.request_handlers[types.CallToolRequest] = handler + return func + + return decorator + + def progress_notification(self): + def decorator( + func: Callable[[str | int, float, float | None], Awaitable[None]], + ): + logger.debug("Registering handler for ProgressNotification") + + async def handler(req: types.ProgressNotification): + await func( + req.params.progressToken, req.params.progress, req.params.total + ) + + self.notification_handlers[types.ProgressNotification] = handler + return func + + return decorator + + def completion(self): + """Provides completions for prompts and resource templates""" + + def decorator( + func: Callable[ + [ + types.PromptReference | types.ResourceReference, + types.CompletionArgument, + ], + Awaitable[types.Completion | None], + ], + ): + logger.debug("Registering handler for CompleteRequest") + + async def handler(req: types.CompleteRequest): + completion = await func(req.params.ref, req.params.argument) + return types.ServerResult( + types.CompleteResult( + completion=completion + if completion is not None + else types.Completion(values=[], total=None, hasMore=None), + ) + ) + + self.request_handlers[types.CompleteRequest] = handler + return func + + return decorator + + async def run( + self, + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + initialization_options: InitializationOptions, + # When False, exceptions are returned as messages to the client. + # When True, exceptions are raised, which will cause the server to shut down + # but also make tracing exceptions much easier during testing and when using + # in-process servers. + raise_exceptions: bool = False, + ): + with warnings.catch_warnings(record=True) as w: + async with ServerSession( + read_stream, write_stream, initialization_options + ) as session: + async for message in session.incoming_messages: + logger.debug(f"Received message: {message}") + + match message: + case RequestResponder(request=types.ClientRequest(root=req)): + logger.info( + f"Processing request of type {type(req).__name__}" + ) + if type(req) in self.request_handlers: + handler = self.request_handlers[type(req)] + logger.debug( + f"Dispatching request of type {type(req).__name__}" + ) + + token = None + try: + # Set our global state that can be retrieved via + # app.get_request_context() + token = request_ctx.set( + RequestContext( + message.request_id, + message.request_meta, + session, + ) + ) + response = await handler(req) + except McpError as err: + response = err.error + except Exception as err: + if raise_exceptions: + raise err + response = types.ErrorData( + code=0, message=str(err), data=None + ) + finally: + # Reset the global state after we are done + if token is not None: + request_ctx.reset(token) + + await message.respond(response) + else: + await message.respond( + types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="Method not found", + ) + ) + + logger.debug("Response sent") + case types.ClientNotification(root=notify): + if type(notify) in self.notification_handlers: + assert type(notify) in self.notification_handlers + + handler = self.notification_handlers[type(notify)] + logger.debug( + f"Dispatching notification of type " + f"{type(notify).__name__}" + ) + + try: + await handler(notify) + except Exception as err: + logger.error( + f"Uncaught exception in notification handler: " + f"{err}" + ) + + for warning in w: + logger.info( + f"Warning: {warning.category.__name__}: {warning.message}" + ) + + +async def _ping_handler(request: types.PingRequest) -> types.ServerResult: + return types.ServerResult(types.EmptyResult()) diff --git a/tests/conftest.py b/tests/conftest.py index 8d792aa..381e5db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,3 +27,7 @@ def mcp_server() -> Server: ] return server + +@pytest.fixture +def anyio_backend(): + return 'asyncio' diff --git a/tests/server/fastmcp/__init__.py b/tests/server/fastmcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/server/fastmcp/prompts/__init__.py b/tests/server/fastmcp/prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/server/fastmcp/prompts/test_base.py b/tests/server/fastmcp/prompts/test_base.py new file mode 100644 index 0000000..63dc230 --- /dev/null +++ b/tests/server/fastmcp/prompts/test_base.py @@ -0,0 +1,194 @@ +from pydantic import FileUrl +import pytest +from mcp.server.fastmcp.prompts.base import ( + Prompt, + UserMessage, + TextContent, + AssistantMessage, + Message, +) +from mcp.types import EmbeddedResource, TextResourceContents + + +class TestRenderPrompt: + async def test_basic_fn(self): + def fn() -> str: + return "Hello, world!" + + prompt = Prompt.from_function(fn) + assert await prompt.render() == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] + + async def test_async_fn(self): + async def fn() -> str: + return "Hello, world!" + + prompt = Prompt.from_function(fn) + assert await prompt.render() == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] + + async def test_fn_with_args(self): + async def fn(name: str, age: int = 30) -> str: + return f"Hello, {name}! You're {age} years old." + + prompt = Prompt.from_function(fn) + assert await prompt.render(arguments=dict(name="World")) == [ + UserMessage( + content=TextContent( + type="text", text="Hello, World! You're 30 years old." + ) + ) + ] + + async def test_fn_with_invalid_kwargs(self): + async def fn(name: str, age: int = 30) -> str: + return f"Hello, {name}! You're {age} years old." + + prompt = Prompt.from_function(fn) + with pytest.raises(ValueError): + await prompt.render(arguments=dict(age=40)) + + async def test_fn_returns_message(self): + async def fn() -> UserMessage: + return UserMessage(content="Hello, world!") + + prompt = Prompt.from_function(fn) + assert await prompt.render() == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] + + async def test_fn_returns_assistant_message(self): + async def fn() -> AssistantMessage: + return AssistantMessage( + content=TextContent(type="text", text="Hello, world!") + ) + + prompt = Prompt.from_function(fn) + assert await prompt.render() == [ + AssistantMessage(content=TextContent(type="text", text="Hello, world!")) + ] + + async def test_fn_returns_multiple_messages(self): + expected = [ + UserMessage("Hello, world!"), + AssistantMessage("How can I help you today?"), + UserMessage("I'm looking for a restaurant in the center of town."), + ] + + async def fn() -> list[Message]: + return expected + + prompt = Prompt.from_function(fn) + assert await prompt.render() == expected + + async def test_fn_returns_list_of_strings(self): + expected = [ + "Hello, world!", + "I'm looking for a restaurant in the center of town.", + ] + + async def fn() -> list[str]: + return expected + + prompt = Prompt.from_function(fn) + assert await prompt.render() == [UserMessage(t) for t in expected] + + async def test_fn_returns_resource_content(self): + """Test returning a message with resource content.""" + + async def fn() -> UserMessage: + return UserMessage( + content=EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=FileUrl("file://file.txt"), + text="File contents", + mimeType="text/plain", + ), + ) + ) + + prompt = Prompt.from_function(fn) + assert await prompt.render() == [ + UserMessage( + content=EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=FileUrl("file://file.txt"), + text="File contents", + mimeType="text/plain", + ), + ) + ) + ] + + async def test_fn_returns_mixed_content(self): + """Test returning messages with mixed content types.""" + + async def fn() -> list[Message]: + return [ + UserMessage(content="Please analyze this file:"), + UserMessage( + content=EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=FileUrl("file://file.txt"), + text="File contents", + mimeType="text/plain", + ), + ) + ), + AssistantMessage(content="I'll help analyze that file."), + ] + + prompt = Prompt.from_function(fn) + assert await prompt.render() == [ + UserMessage( + content=TextContent(type="text", text="Please analyze this file:") + ), + UserMessage( + content=EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=FileUrl("file://file.txt"), + text="File contents", + mimeType="text/plain", + ), + ) + ), + AssistantMessage( + content=TextContent(type="text", text="I'll help analyze that file.") + ), + ] + + async def test_fn_returns_dict_with_resource(self): + """Test returning a dict with resource content.""" + + async def fn() -> dict: + return { + "role": "user", + "content": { + "type": "resource", + "resource": { + "uri": FileUrl("file://file.txt"), + "text": "File contents", + "mimeType": "text/plain", + }, + }, + } + + prompt = Prompt.from_function(fn) + assert await prompt.render() == [ + UserMessage( + content=EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=FileUrl("file://file.txt"), + text="File contents", + mimeType="text/plain", + ), + ) + ) + ] diff --git a/tests/server/fastmcp/prompts/test_manager.py b/tests/server/fastmcp/prompts/test_manager.py new file mode 100644 index 0000000..7b97b30 --- /dev/null +++ b/tests/server/fastmcp/prompts/test_manager.py @@ -0,0 +1,107 @@ +import pytest +from mcp.server.fastmcp.prompts.base import UserMessage, TextContent, Prompt +from mcp.server.fastmcp.prompts.manager import PromptManager + + +class TestPromptManager: + def test_add_prompt(self): + """Test adding a prompt to the manager.""" + + def fn() -> str: + return "Hello, world!" + + manager = PromptManager() + prompt = Prompt.from_function(fn) + added = manager.add_prompt(prompt) + assert added == prompt + assert manager.get_prompt("fn") == prompt + + def test_add_duplicate_prompt(self, caplog): + """Test adding the same prompt twice.""" + + def fn() -> str: + return "Hello, world!" + + manager = PromptManager() + prompt = Prompt.from_function(fn) + first = manager.add_prompt(prompt) + second = manager.add_prompt(prompt) + assert first == second + assert "Prompt already exists" in caplog.text + + def test_disable_warn_on_duplicate_prompts(self, caplog): + """Test disabling warning on duplicate prompts.""" + + def fn() -> str: + return "Hello, world!" + + manager = PromptManager(warn_on_duplicate_prompts=False) + prompt = Prompt.from_function(fn) + first = manager.add_prompt(prompt) + second = manager.add_prompt(prompt) + assert first == second + assert "Prompt already exists" not in caplog.text + + def test_list_prompts(self): + """Test listing all prompts.""" + + def fn1() -> str: + return "Hello, world!" + + def fn2() -> str: + return "Goodbye, world!" + + manager = PromptManager() + prompt1 = Prompt.from_function(fn1) + prompt2 = Prompt.from_function(fn2) + manager.add_prompt(prompt1) + manager.add_prompt(prompt2) + prompts = manager.list_prompts() + assert len(prompts) == 2 + assert prompts == [prompt1, prompt2] + + async def test_render_prompt(self): + """Test rendering a prompt.""" + + def fn() -> str: + return "Hello, world!" + + manager = PromptManager() + prompt = Prompt.from_function(fn) + manager.add_prompt(prompt) + messages = await manager.render_prompt("fn") + assert messages == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] + + async def test_render_prompt_with_args(self): + """Test rendering a prompt with arguments.""" + + def fn(name: str) -> str: + return f"Hello, {name}!" + + manager = PromptManager() + prompt = Prompt.from_function(fn) + manager.add_prompt(prompt) + messages = await manager.render_prompt("fn", arguments={"name": "World"}) + assert messages == [ + UserMessage(content=TextContent(type="text", text="Hello, World!")) + ] + + async def test_render_unknown_prompt(self): + """Test rendering a non-existent prompt.""" + manager = PromptManager() + with pytest.raises(ValueError, match="Unknown prompt: unknown"): + await manager.render_prompt("unknown") + + async def test_render_prompt_with_missing_args(self): + """Test rendering a prompt with missing required arguments.""" + + def fn(name: str) -> str: + return f"Hello, {name}!" + + manager = PromptManager() + prompt = Prompt.from_function(fn) + manager.add_prompt(prompt) + with pytest.raises(ValueError, match="Missing required arguments"): + await manager.render_prompt("fn") diff --git a/tests/server/fastmcp/resources/__init__.py b/tests/server/fastmcp/resources/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/server/fastmcp/resources/test_file_resources.py b/tests/server/fastmcp/resources/test_file_resources.py new file mode 100644 index 0000000..f9ec159 --- /dev/null +++ b/tests/server/fastmcp/resources/test_file_resources.py @@ -0,0 +1,115 @@ +import os + +import pytest +from pathlib import Path +from tempfile import NamedTemporaryFile +from pydantic import FileUrl + +from mcp.server.fastmcp.resources import FileResource + + +@pytest.fixture +def temp_file(): + """Create a temporary file for testing. + + File is automatically cleaned up after the test if it still exists. + """ + content = "test content" + with NamedTemporaryFile(mode="w", delete=False) as f: + f.write(content) + path = Path(f.name).resolve() + yield path + try: + path.unlink() + except FileNotFoundError: + pass # File was already deleted by the test + + +class TestFileResource: + """Test FileResource functionality.""" + + def test_file_resource_creation(self, temp_file: Path): + """Test creating a FileResource.""" + resource = FileResource( + uri=FileUrl(temp_file.as_uri()), + name="test", + description="test file", + path=temp_file, + ) + assert str(resource.uri) == temp_file.as_uri() + assert resource.name == "test" + assert resource.description == "test file" + assert resource.mime_type == "text/plain" # default + assert resource.path == temp_file + assert resource.is_binary is False # default + + def test_file_resource_str_path_conversion(self, temp_file: Path): + """Test FileResource handles string paths.""" + resource = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test", + path=Path(str(temp_file)), + ) + assert isinstance(resource.path, Path) + assert resource.path.is_absolute() + + async def test_read_text_file(self, temp_file: Path): + """Test reading a text file.""" + resource = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test", + path=temp_file, + ) + content = await resource.read() + assert content == "test content" + assert resource.mime_type == "text/plain" + + async def test_read_binary_file(self, temp_file: Path): + """Test reading a file as binary.""" + resource = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test", + path=temp_file, + is_binary=True, + ) + content = await resource.read() + assert isinstance(content, bytes) + assert content == b"test content" + + def test_relative_path_error(self): + """Test error on relative path.""" + with pytest.raises(ValueError, match="Path must be absolute"): + FileResource( + uri=FileUrl("file:///test.txt"), + name="test", + path=Path("test.txt"), + ) + + async def test_missing_file_error(self, temp_file: Path): + """Test error when file doesn't exist.""" + # Create path to non-existent file + missing = temp_file.parent / "missing.txt" + resource = FileResource( + uri=FileUrl("file:///missing.txt"), + name="test", + path=missing, + ) + with pytest.raises(ValueError, match="Error reading file"): + await resource.read() + + @pytest.mark.skipif( + os.name == "nt", reason="File permissions behave differently on Windows" + ) + async def test_permission_error(self, temp_file: Path): + """Test reading a file without permissions.""" + temp_file.chmod(0o000) # Remove all permissions + try: + resource = FileResource( + uri=FileUrl(temp_file.as_uri()), + name="test", + path=temp_file, + ) + with pytest.raises(ValueError, match="Error reading file"): + await resource.read() + finally: + temp_file.chmod(0o644) # Restore permissions diff --git a/tests/server/fastmcp/resources/test_function_resources.py b/tests/server/fastmcp/resources/test_function_resources.py new file mode 100644 index 0000000..e132e5f --- /dev/null +++ b/tests/server/fastmcp/resources/test_function_resources.py @@ -0,0 +1,115 @@ +from pydantic import BaseModel, AnyUrl +import pytest +from mcp.server.fastmcp.resources import FunctionResource + + +class TestFunctionResource: + """Test FunctionResource functionality.""" + + def test_function_resource_creation(self): + """Test creating a FunctionResource.""" + + def my_func() -> str: + return "test content" + + resource = FunctionResource( + uri=AnyUrl("fn://test"), + name="test", + description="test function", + fn=my_func, + ) + assert str(resource.uri) == "fn://test" + assert resource.name == "test" + assert resource.description == "test function" + assert resource.mime_type == "text/plain" # default + assert resource.fn == my_func + + async def test_read_text(self): + """Test reading text from a FunctionResource.""" + + def get_data() -> str: + return "Hello, world!" + + resource = FunctionResource( + uri=AnyUrl("function://test"), + name="test", + fn=get_data, + ) + content = await resource.read() + assert content == "Hello, world!" + assert resource.mime_type == "text/plain" + + async def test_read_binary(self): + """Test reading binary data from a FunctionResource.""" + + def get_data() -> bytes: + return b"Hello, world!" + + resource = FunctionResource( + uri=AnyUrl("function://test"), + name="test", + fn=get_data, + ) + content = await resource.read() + assert content == b"Hello, world!" + + async def test_json_conversion(self): + """Test automatic JSON conversion of non-string results.""" + + def get_data() -> dict: + return {"key": "value"} + + resource = FunctionResource( + uri=AnyUrl("function://test"), + name="test", + fn=get_data, + ) + content = await resource.read() + assert isinstance(content, str) + assert '"key": "value"' in content + + async def test_error_handling(self): + """Test error handling in FunctionResource.""" + + def failing_func() -> str: + raise ValueError("Test error") + + resource = FunctionResource( + uri=AnyUrl("function://test"), + name="test", + fn=failing_func, + ) + with pytest.raises(ValueError, match="Error reading resource function://test"): + await resource.read() + + async def test_basemodel_conversion(self): + """Test handling of BaseModel types.""" + + class MyModel(BaseModel): + name: str + + resource = FunctionResource( + uri=AnyUrl("function://test"), + name="test", + fn=lambda: MyModel(name="test"), + ) + content = await resource.read() + assert content == '{"name": "test"}' + + async def test_custom_type_conversion(self): + """Test handling of custom types.""" + + class CustomData: + def __str__(self) -> str: + return "custom data" + + def get_data() -> CustomData: + return CustomData() + + resource = FunctionResource( + uri=AnyUrl("function://test"), + name="test", + fn=get_data, + ) + content = await resource.read() + assert isinstance(content, str) diff --git a/tests/server/fastmcp/resources/test_resource_manager.py b/tests/server/fastmcp/resources/test_resource_manager.py new file mode 100644 index 0000000..d8d04e5 --- /dev/null +++ b/tests/server/fastmcp/resources/test_resource_manager.py @@ -0,0 +1,137 @@ +import pytest +from pathlib import Path +from tempfile import NamedTemporaryFile +from pydantic import AnyUrl, FileUrl + +from mcp.server.fastmcp.resources import ( + FileResource, + FunctionResource, + ResourceManager, + ResourceTemplate, +) + + +@pytest.fixture +def temp_file(): + """Create a temporary file for testing. + + File is automatically cleaned up after the test if it still exists. + """ + content = "test content" + with NamedTemporaryFile(mode="w", delete=False) as f: + f.write(content) + path = Path(f.name).resolve() + yield path + try: + path.unlink() + except FileNotFoundError: + pass # File was already deleted by the test + + +class TestResourceManager: + """Test ResourceManager functionality.""" + + def test_add_resource(self, temp_file: Path): + """Test adding a resource.""" + manager = ResourceManager() + resource = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test", + path=temp_file, + ) + added = manager.add_resource(resource) + assert added == resource + assert manager.list_resources() == [resource] + + def test_add_duplicate_resource(self, temp_file: Path): + """Test adding the same resource twice.""" + manager = ResourceManager() + resource = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test", + path=temp_file, + ) + first = manager.add_resource(resource) + second = manager.add_resource(resource) + assert first == second + assert manager.list_resources() == [resource] + + def test_warn_on_duplicate_resources(self, temp_file: Path, caplog): + """Test warning on duplicate resources.""" + manager = ResourceManager() + resource = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test", + path=temp_file, + ) + manager.add_resource(resource) + manager.add_resource(resource) + assert "Resource already exists" in caplog.text + + def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog): + """Test disabling warning on duplicate resources.""" + manager = ResourceManager(warn_on_duplicate_resources=False) + resource = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test", + path=temp_file, + ) + manager.add_resource(resource) + manager.add_resource(resource) + assert "Resource already exists" not in caplog.text + + async def test_get_resource(self, temp_file: Path): + """Test getting a resource by URI.""" + manager = ResourceManager() + resource = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test", + path=temp_file, + ) + manager.add_resource(resource) + retrieved = await manager.get_resource(resource.uri) + assert retrieved == resource + + async def test_get_resource_from_template(self): + """Test getting a resource through a template.""" + manager = ResourceManager() + + def greet(name: str) -> str: + return f"Hello, {name}!" + + template = ResourceTemplate.from_function( + fn=greet, + uri_template="greet://{name}", + name="greeter", + ) + manager._templates[template.uri_template] = template + + resource = await manager.get_resource(AnyUrl("greet://world")) + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello, world!" + + async def test_get_unknown_resource(self): + """Test getting a non-existent resource.""" + manager = ResourceManager() + with pytest.raises(ValueError, match="Unknown resource"): + await manager.get_resource(AnyUrl("unknown://test")) + + def test_list_resources(self, temp_file: Path): + """Test listing all resources.""" + manager = ResourceManager() + resource1 = FileResource( + uri=FileUrl(f"file://{temp_file}"), + name="test1", + path=temp_file, + ) + resource2 = FileResource( + uri=FileUrl(f"file://{temp_file}2"), + name="test2", + path=temp_file, + ) + manager.add_resource(resource1) + manager.add_resource(resource2) + resources = manager.list_resources() + assert len(resources) == 2 + assert resources == [resource1, resource2] diff --git a/tests/server/fastmcp/resources/test_resource_template.py b/tests/server/fastmcp/resources/test_resource_template.py new file mode 100644 index 0000000..95d0585 --- /dev/null +++ b/tests/server/fastmcp/resources/test_resource_template.py @@ -0,0 +1,181 @@ +import json +import pytest +from pydantic import BaseModel + +from mcp.server.fastmcp.resources import FunctionResource, ResourceTemplate + + +class TestResourceTemplate: + """Test ResourceTemplate functionality.""" + + def test_template_creation(self): + """Test creating a template from a function.""" + + def my_func(key: str, value: int) -> dict: + return {"key": key, "value": value} + + template = ResourceTemplate.from_function( + fn=my_func, + uri_template="test://{key}/{value}", + name="test", + ) + assert template.uri_template == "test://{key}/{value}" + assert template.name == "test" + assert template.mime_type == "text/plain" # default + test_input = {"key": "test", "value": 42} + assert template.fn(**test_input) == my_func(**test_input) + + def test_template_matches(self): + """Test matching URIs against a template.""" + + def my_func(key: str, value: int) -> dict: + return {"key": key, "value": value} + + template = ResourceTemplate.from_function( + fn=my_func, + uri_template="test://{key}/{value}", + name="test", + ) + + # Valid match + params = template.matches("test://foo/123") + assert params == {"key": "foo", "value": "123"} + + # No match + assert template.matches("test://foo") is None + assert template.matches("other://foo/123") is None + + async def test_create_resource(self): + """Test creating a resource from a template.""" + + def my_func(key: str, value: int) -> dict: + return {"key": key, "value": value} + + template = ResourceTemplate.from_function( + fn=my_func, + uri_template="test://{key}/{value}", + name="test", + ) + + resource = await template.create_resource( + "test://foo/123", + {"key": "foo", "value": 123}, + ) + + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert isinstance(content, str) + data = json.loads(content) + assert data == {"key": "foo", "value": 123} + + async def test_template_error(self): + """Test error handling in template resource creation.""" + + def failing_func(x: str) -> str: + raise ValueError("Test error") + + template = ResourceTemplate.from_function( + fn=failing_func, + uri_template="fail://{x}", + name="fail", + ) + + with pytest.raises(ValueError, match="Error creating resource from template"): + await template.create_resource("fail://test", {"x": "test"}) + + async def test_async_text_resource(self): + """Test creating a text resource from async function.""" + + async def greet(name: str) -> str: + return f"Hello, {name}!" + + template = ResourceTemplate.from_function( + fn=greet, + uri_template="greet://{name}", + name="greeter", + ) + + resource = await template.create_resource( + "greet://world", + {"name": "world"}, + ) + + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello, world!" + + async def test_async_binary_resource(self): + """Test creating a binary resource from async function.""" + + async def get_bytes(value: str) -> bytes: + return value.encode() + + template = ResourceTemplate.from_function( + fn=get_bytes, + uri_template="bytes://{value}", + name="bytes", + ) + + resource = await template.create_resource( + "bytes://test", + {"value": "test"}, + ) + + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == b"test" + + async def test_basemodel_conversion(self): + """Test handling of BaseModel types.""" + + class MyModel(BaseModel): + key: str + value: int + + def get_data(key: str, value: int) -> MyModel: + return MyModel(key=key, value=value) + + template = ResourceTemplate.from_function( + fn=get_data, + uri_template="test://{key}/{value}", + name="test", + ) + + resource = await template.create_resource( + "test://foo/123", + {"key": "foo", "value": 123}, + ) + + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert isinstance(content, str) + data = json.loads(content) + assert data == {"key": "foo", "value": 123} + + async def test_custom_type_conversion(self): + """Test handling of custom types.""" + + class CustomData: + def __init__(self, value: str): + self.value = value + + def __str__(self) -> str: + return self.value + + def get_data(value: str) -> CustomData: + return CustomData(value) + + template = ResourceTemplate.from_function( + fn=get_data, + uri_template="test://{value}", + name="test", + ) + + resource = await template.create_resource( + "test://hello", + {"value": "hello"}, + ) + + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "hello" diff --git a/tests/server/fastmcp/resources/test_resources.py b/tests/server/fastmcp/resources/test_resources.py new file mode 100644 index 0000000..dddcd56 --- /dev/null +++ b/tests/server/fastmcp/resources/test_resources.py @@ -0,0 +1,100 @@ +import pytest +from pydantic import AnyUrl + +from mcp.server.fastmcp.resources import FunctionResource, Resource + + +class TestResourceValidation: + """Test base Resource validation.""" + + def test_resource_uri_validation(self): + """Test URI validation.""" + + def dummy_func() -> str: + return "data" + + # Valid URI + resource = FunctionResource( + uri=AnyUrl("http://example.com/data"), + name="test", + fn=dummy_func, + ) + assert str(resource.uri) == "http://example.com/data" + + # Missing protocol + with pytest.raises(ValueError, match="Input should be a valid URL"): + FunctionResource( + uri=AnyUrl("invalid"), + name="test", + fn=dummy_func, + ) + + # Missing host + with pytest.raises(ValueError, match="Input should be a valid URL"): + FunctionResource( + uri=AnyUrl("http://"), + name="test", + fn=dummy_func, + ) + + def test_resource_name_from_uri(self): + """Test name is extracted from URI if not provided.""" + + def dummy_func() -> str: + return "data" + + resource = FunctionResource( + uri=AnyUrl("resource://my-resource"), + fn=dummy_func, + ) + assert resource.name == "resource://my-resource" + + def test_resource_name_validation(self): + """Test name validation.""" + + def dummy_func() -> str: + return "data" + + # Must provide either name or URI + with pytest.raises(ValueError, match="Either name or uri must be provided"): + FunctionResource( + fn=dummy_func, + ) + + # Explicit name takes precedence over URI + resource = FunctionResource( + uri=AnyUrl("resource://uri-name"), + name="explicit-name", + fn=dummy_func, + ) + assert resource.name == "explicit-name" + + def test_resource_mime_type(self): + """Test mime type handling.""" + + def dummy_func() -> str: + return "data" + + # Default mime type + resource = FunctionResource( + uri=AnyUrl("resource://test"), + fn=dummy_func, + ) + assert resource.mime_type == "text/plain" + + # Custom mime type + resource = FunctionResource( + uri=AnyUrl("resource://test"), + fn=dummy_func, + mime_type="application/json", + ) + assert resource.mime_type == "application/json" + + async def test_resource_read_abstract(self): + """Test that Resource.read() is abstract.""" + + class ConcreteResource(Resource): + pass + + with pytest.raises(TypeError, match="abstract method"): + ConcreteResource(uri=AnyUrl("test://test"), name="test") # type: ignore diff --git a/tests/server/fastmcp/servers/__init__.py b/tests/server/fastmcp/servers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/server/fastmcp/servers/test_file_server.py b/tests/server/fastmcp/servers/test_file_server.py new file mode 100644 index 0000000..d9b34d6 --- /dev/null +++ b/tests/server/fastmcp/servers/test_file_server.py @@ -0,0 +1,114 @@ +import json +from mcp.server.fastmcp import FastMCP +import pytest +from pathlib import Path + + +@pytest.fixture() +def test_dir(tmp_path_factory) -> Path: + """Create a temporary directory with test files.""" + tmp = tmp_path_factory.mktemp("test_files") + + # Create test files + (tmp / "example.py").write_text("print('hello world')") + (tmp / "readme.md").write_text("# Test Directory\nThis is a test.") + (tmp / "config.json").write_text('{"test": true}') + + return tmp + + +@pytest.fixture +def mcp() -> FastMCP: + mcp = FastMCP() + + return mcp + + +@pytest.fixture(autouse=True) +def resources(mcp: FastMCP, test_dir: Path) -> FastMCP: + @mcp.resource("dir://test_dir") + def list_test_dir() -> list[str]: + """List the files in the test directory""" + return [str(f) for f in test_dir.iterdir()] + + @mcp.resource("file://test_dir/example.py") + def read_example_py() -> str: + """Read the example.py file""" + try: + return (test_dir / "example.py").read_text() + except FileNotFoundError: + return "File not found" + + @mcp.resource("file://test_dir/readme.md") + def read_readme_md() -> str: + """Read the readme.md file""" + try: + return (test_dir / "readme.md").read_text() + except FileNotFoundError: + return "File not found" + + @mcp.resource("file://test_dir/config.json") + def read_config_json() -> str: + """Read the config.json file""" + try: + return (test_dir / "config.json").read_text() + except FileNotFoundError: + return "File not found" + + return mcp + + +@pytest.fixture(autouse=True) +def tools(mcp: FastMCP, test_dir: Path) -> FastMCP: + @mcp.tool() + def delete_file(path: str) -> bool: + # ensure path is in test_dir + if Path(path).resolve().parent != test_dir: + raise ValueError(f"Path must be in test_dir: {path}") + Path(path).unlink() + return True + + return mcp + + +async def test_list_resources(mcp: FastMCP): + resources = await mcp.list_resources() + assert len(resources) == 4 + + assert [str(r.uri) for r in resources] == [ + "dir://test_dir", + "file://test_dir/example.py", + "file://test_dir/readme.md", + "file://test_dir/config.json", + ] + + +async def test_read_resource_dir(mcp: FastMCP): + files = await mcp.read_resource("dir://test_dir") + files = json.loads(files) + + assert sorted([Path(f).name for f in files]) == [ + "config.json", + "example.py", + "readme.md", + ] + + +async def test_read_resource_file(mcp: FastMCP): + result = await mcp.read_resource("file://test_dir/example.py") + assert result == "print('hello world')" + + +async def test_delete_file(mcp: FastMCP, test_dir: Path): + await mcp.call_tool( + "delete_file", arguments=dict(path=str(test_dir / "example.py")) + ) + assert not (test_dir / "example.py").exists() + + +async def test_delete_file_and_check_resources(mcp: FastMCP, test_dir: Path): + await mcp.call_tool( + "delete_file", arguments=dict(path=str(test_dir / "example.py")) + ) + result = await mcp.read_resource("file://test_dir/example.py") + assert result == "File not found" diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py new file mode 100644 index 0000000..b329c90 --- /dev/null +++ b/tests/server/fastmcp/test_func_metadata.py @@ -0,0 +1,361 @@ +from typing import Annotated + +import annotated_types +import pytest +from pydantic import BaseModel, Field + +from mcp.server.fastmcp.utilities.func_metadata import func_metadata + + +class SomeInputModelA(BaseModel): + pass + + +class SomeInputModelB(BaseModel): + class InnerModel(BaseModel): + x: int + + how_many_shrimp: Annotated[int, Field(description="How many shrimp in the tank???")] + ok: InnerModel + y: None + + +def complex_arguments_fn( + an_int: int, + must_be_none: None, + must_be_none_dumb_annotation: Annotated[None, "blah"], + list_of_ints: list[int], + # list[str] | str is an interesting case because if it comes in as JSON like + # "[\"a\", \"b\"]" then it will be naively parsed as a string. + list_str_or_str: list[str] | str, + an_int_annotated_with_field: Annotated[ + int, Field(description="An int with a field") + ], + an_int_annotated_with_field_and_others: Annotated[ + int, + str, # Should be ignored, really + Field(description="An int with a field"), + annotated_types.Gt(1), + ], + an_int_annotated_with_junk: Annotated[ + int, + "123", + 456, + ], + field_with_default_via_field_annotation_before_nondefault_arg: Annotated[ + int, Field(1) + ], + unannotated, + my_model_a: SomeInputModelA, + my_model_a_forward_ref: "SomeInputModelA", + my_model_b: SomeInputModelB, + an_int_annotated_with_field_default: Annotated[ + int, + Field(1, description="An int with a field"), + ], + unannotated_with_default=5, + my_model_a_with_default: SomeInputModelA = SomeInputModelA(), # noqa: B008 + an_int_with_default: int = 1, + must_be_none_with_default: None = None, + an_int_with_equals_field: int = Field(1, ge=0), + int_annotated_with_default: Annotated[int, Field(description="hey")] = 5, +) -> str: + _ = ( + an_int, + must_be_none, + must_be_none_dumb_annotation, + list_of_ints, + list_str_or_str, + an_int_annotated_with_field, + an_int_annotated_with_field_and_others, + an_int_annotated_with_junk, + field_with_default_via_field_annotation_before_nondefault_arg, + unannotated, + an_int_annotated_with_field_default, + unannotated_with_default, + my_model_a, + my_model_a_forward_ref, + my_model_b, + my_model_a_with_default, + an_int_with_default, + must_be_none_with_default, + an_int_with_equals_field, + int_annotated_with_default, + ) + return "ok!" + + +async def test_complex_function_runtime_arg_validation_non_json(): + """Test that basic non-JSON arguments are validated correctly""" + meta = func_metadata(complex_arguments_fn) + + # Test with minimum required arguments + result = await meta.call_fn_with_arg_validation( + complex_arguments_fn, + fn_is_async=False, + arguments_to_validate={ + "an_int": 1, + "must_be_none": None, + "must_be_none_dumb_annotation": None, + "list_of_ints": [1, 2, 3], + "list_str_or_str": "hello", + "an_int_annotated_with_field": 42, + "an_int_annotated_with_field_and_others": 5, + "an_int_annotated_with_junk": 100, + "unannotated": "test", + "my_model_a": {}, + "my_model_a_forward_ref": {}, + "my_model_b": {"how_many_shrimp": 5, "ok": {"x": 1}, "y": None}, + }, + arguments_to_pass_directly=None, + ) + assert result == "ok!" + + # Test with invalid types + with pytest.raises(ValueError): + await meta.call_fn_with_arg_validation( + complex_arguments_fn, + fn_is_async=False, + arguments_to_validate={"an_int": "not an int"}, + arguments_to_pass_directly=None, + ) + + +async def test_complex_function_runtime_arg_validation_with_json(): + """Test that JSON string arguments are parsed and validated correctly""" + meta = func_metadata(complex_arguments_fn) + + result = await meta.call_fn_with_arg_validation( + complex_arguments_fn, + fn_is_async=False, + arguments_to_validate={ + "an_int": 1, + "must_be_none": None, + "must_be_none_dumb_annotation": None, + "list_of_ints": "[1, 2, 3]", # JSON string + "list_str_or_str": '["a", "b", "c"]', # JSON string + "an_int_annotated_with_field": 42, + "an_int_annotated_with_field_and_others": "5", # JSON string + "an_int_annotated_with_junk": 100, + "unannotated": "test", + "my_model_a": "{}", # JSON string + "my_model_a_forward_ref": "{}", # JSON string + "my_model_b": '{"how_many_shrimp": 5, "ok": {"x": 1}, "y": null}', # JSON string + }, + arguments_to_pass_directly=None, + ) + assert result == "ok!" + + +def test_str_vs_list_str(): + """Test handling of string vs list[str] type annotations. + + This is tricky as '"hello"' can be parsed as a JSON string or a Python string. + We want to make sure it's kept as a python string. + """ + + def func_with_str_types(str_or_list: str | list[str]): + return str_or_list + + meta = func_metadata(func_with_str_types) + + # Test string input for union type + result = meta.pre_parse_json({"str_or_list": "hello"}) + assert result["str_or_list"] == "hello" + + # Test string input that contains valid JSON for union type + # We want to see here that the JSON-vali string is NOT parsed as JSON, but rather + # kept as a raw string + result = meta.pre_parse_json({"str_or_list": '"hello"'}) + assert result["str_or_list"] == '"hello"' + + # Test list input for union type + result = meta.pre_parse_json({"str_or_list": '["hello", "world"]'}) + assert result["str_or_list"] == ["hello", "world"] + + +def test_skip_names(): + """Test that skipped parameters are not included in the model""" + + def func_with_many_params( + keep_this: int, skip_this: str, also_keep: float, also_skip: bool + ): + return keep_this, skip_this, also_keep, also_skip + + # Skip some parameters + meta = func_metadata(func_with_many_params, skip_names=["skip_this", "also_skip"]) + + # Check model fields + assert "keep_this" in meta.arg_model.model_fields + assert "also_keep" in meta.arg_model.model_fields + assert "skip_this" not in meta.arg_model.model_fields + assert "also_skip" not in meta.arg_model.model_fields + + # Validate that we can call with only non-skipped parameters + model: BaseModel = meta.arg_model.model_validate({"keep_this": 1, "also_keep": 2.5}) # type: ignore + assert model.keep_this == 1 # type: ignore + assert model.also_keep == 2.5 # type: ignore + + +async def test_lambda_function(): + """Test lambda function schema and validation""" + fn = lambda x, y=5: x # noqa: E731 + meta = func_metadata(lambda x, y=5: x) + + # Test schema + assert meta.arg_model.model_json_schema() == { + "properties": { + "x": {"title": "x", "type": "string"}, + "y": {"default": 5, "title": "y", "type": "string"}, + }, + "required": ["x"], + "title": "Arguments", + "type": "object", + } + + async def check_call(args): + return await meta.call_fn_with_arg_validation( + fn, + fn_is_async=False, + arguments_to_validate=args, + arguments_to_pass_directly=None, + ) + + # Basic calls + assert await check_call({"x": "hello"}) == "hello" + assert await check_call({"x": "hello", "y": "world"}) == "hello" + assert await check_call({"x": '"hello"'}) == '"hello"' + + # Missing required arg + with pytest.raises(ValueError): + await check_call({"y": "world"}) + + +def test_complex_function_json_schema(): + meta = func_metadata(complex_arguments_fn) + assert meta.arg_model.model_json_schema() == { + "$defs": { + "InnerModel": { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "InnerModel", + "type": "object", + }, + "SomeInputModelA": { + "properties": {}, + "title": "SomeInputModelA", + "type": "object", + }, + "SomeInputModelB": { + "properties": { + "how_many_shrimp": { + "description": "How many shrimp in the tank???", + "title": "How Many Shrimp", + "type": "integer", + }, + "ok": {"$ref": "#/$defs/InnerModel"}, + "y": {"title": "Y", "type": "null"}, + }, + "required": ["how_many_shrimp", "ok", "y"], + "title": "SomeInputModelB", + "type": "object", + }, + }, + "properties": { + "an_int": {"title": "An Int", "type": "integer"}, + "must_be_none": {"title": "Must Be None", "type": "null"}, + "must_be_none_dumb_annotation": { + "title": "Must Be None Dumb Annotation", + "type": "null", + }, + "list_of_ints": { + "items": {"type": "integer"}, + "title": "List Of Ints", + "type": "array", + }, + "list_str_or_str": { + "anyOf": [ + {"items": {"type": "string"}, "type": "array"}, + {"type": "string"}, + ], + "title": "List Str Or Str", + }, + "an_int_annotated_with_field": { + "description": "An int with a field", + "title": "An Int Annotated With Field", + "type": "integer", + }, + "an_int_annotated_with_field_and_others": { + "description": "An int with a field", + "exclusiveMinimum": 1, + "title": "An Int Annotated With Field And Others", + "type": "integer", + }, + "an_int_annotated_with_junk": { + "title": "An Int Annotated With Junk", + "type": "integer", + }, + "field_with_default_via_field_annotation_before_nondefault_arg": { + "default": 1, + "title": "Field With Default Via Field Annotation Before Nondefault Arg", + "type": "integer", + }, + "unannotated": {"title": "unannotated", "type": "string"}, + "my_model_a": {"$ref": "#/$defs/SomeInputModelA"}, + "my_model_a_forward_ref": {"$ref": "#/$defs/SomeInputModelA"}, + "my_model_b": {"$ref": "#/$defs/SomeInputModelB"}, + "an_int_annotated_with_field_default": { + "default": 1, + "description": "An int with a field", + "title": "An Int Annotated With Field Default", + "type": "integer", + }, + "unannotated_with_default": { + "default": 5, + "title": "unannotated_with_default", + "type": "string", + }, + "my_model_a_with_default": { + "$ref": "#/$defs/SomeInputModelA", + "default": {}, + }, + "an_int_with_default": { + "default": 1, + "title": "An Int With Default", + "type": "integer", + }, + "must_be_none_with_default": { + "default": None, + "title": "Must Be None With Default", + "type": "null", + }, + "an_int_with_equals_field": { + "default": 1, + "minimum": 0, + "title": "An Int With Equals Field", + "type": "integer", + }, + "int_annotated_with_default": { + "default": 5, + "description": "hey", + "title": "Int Annotated With Default", + "type": "integer", + }, + }, + "required": [ + "an_int", + "must_be_none", + "must_be_none_dumb_annotation", + "list_of_ints", + "list_str_or_str", + "an_int_annotated_with_field", + "an_int_annotated_with_field_and_others", + "an_int_annotated_with_junk", + "unannotated", + "my_model_a", + "my_model_a_forward_ref", + "my_model_b", + ], + "title": "complex_arguments_fnArguments", + "type": "object", + } diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py new file mode 100644 index 0000000..27e13f2 --- /dev/null +++ b/tests/server/fastmcp/test_server.py @@ -0,0 +1,656 @@ +import base64 +from pathlib import Path +from typing import TYPE_CHECKING, Union + +import pytest +from mcp.shared.exceptions import McpError +from mcp.shared.memory import ( + create_connected_server_and_client_session as client_session, +) +from mcp.types import ( + ImageContent, + TextContent, + TextResourceContents, + BlobResourceContents, +) +from pydantic import AnyUrl + +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage +from mcp.server.fastmcp.resources import FileResource, FunctionResource +from mcp.server.fastmcp.utilities.types import Image + +if TYPE_CHECKING: + from mcp.server.fastmcp import Context + + +class TestServer: + async def test_create_server(self): + mcp = FastMCP() + assert mcp.name == "FastMCP" + + async def test_add_tool_decorator(self): + mcp = FastMCP() + + @mcp.tool() + def add(x: int, y: int) -> int: + return x + y + + assert len(mcp._tool_manager.list_tools()) == 1 + + async def test_add_tool_decorator_incorrect_usage(self): + mcp = FastMCP() + + with pytest.raises(TypeError, match="The @tool decorator was used incorrectly"): + + @mcp.tool # Missing parentheses #type: ignore + def add(x: int, y: int) -> int: + return x + y + + async def test_add_resource_decorator(self): + mcp = FastMCP() + + @mcp.resource("r://{x}") + def get_data(x: str) -> str: + return f"Data: {x}" + + assert len(mcp._resource_manager._templates) == 1 + + async def test_add_resource_decorator_incorrect_usage(self): + mcp = FastMCP() + + with pytest.raises( + TypeError, match="The @resource decorator was used incorrectly" + ): + + @mcp.resource # Missing parentheses #type: ignore + def get_data(x: str) -> str: + return f"Data: {x}" + + +def tool_fn(x: int, y: int) -> int: + return x + y + + +def error_tool_fn() -> None: + raise ValueError("Test error") + + +def image_tool_fn(path: str) -> Image: + return Image(path) + + +def mixed_content_tool_fn() -> list[Union[TextContent, ImageContent]]: + return [ + TextContent(type="text", text="Hello"), + ImageContent(type="image", data="abc", mimeType="image/png"), + ] + + +class TestServerTools: + async def test_add_tool(self): + mcp = FastMCP() + mcp.add_tool(tool_fn) + mcp.add_tool(tool_fn) + assert len(mcp._tool_manager.list_tools()) == 1 + + async def test_list_tools(self): + mcp = FastMCP() + mcp.add_tool(tool_fn) + async with client_session(mcp._mcp_server) as client: + tools = await client.list_tools() + assert len(tools.tools) == 1 + + async def test_call_tool(self): + mcp = FastMCP() + mcp.add_tool(tool_fn) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("my_tool", {"arg1": "value"}) + assert not hasattr(result, "error") + assert len(result.content) > 0 + + async def test_tool_exception_handling(self): + mcp = FastMCP() + mcp.add_tool(error_tool_fn) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("error_tool_fn", {}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert "Test error" in content.text + assert result.isError is True + + async def test_tool_error_handling(self): + mcp = FastMCP() + mcp.add_tool(error_tool_fn) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("error_tool_fn", {}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert "Test error" in content.text + assert result.isError is True + + async def test_tool_error_details(self): + """Test that exception details are properly formatted in the response""" + mcp = FastMCP() + mcp.add_tool(error_tool_fn) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("error_tool_fn", {}) + content = result.content[0] + assert isinstance(content, TextContent) + assert isinstance(content.text, str) + assert "Test error" in content.text + assert result.isError is True + + async def test_tool_return_value_conversion(self): + mcp = FastMCP() + mcp.add_tool(tool_fn) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("tool_fn", {"x": 1, "y": 2}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "3" + + async def test_tool_image_helper(self, tmp_path: Path): + # Create a test image + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + mcp = FastMCP() + mcp.add_tool(image_tool_fn) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("image_tool_fn", {"path": str(image_path)}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, ImageContent) + assert content.type == "image" + assert content.mimeType == "image/png" + # Verify base64 encoding + decoded = base64.b64decode(content.data) + assert decoded == b"fake png data" + + async def test_tool_mixed_content(self): + mcp = FastMCP() + mcp.add_tool(mixed_content_tool_fn) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("mixed_content_tool_fn", {}) + assert len(result.content) == 2 + content1 = result.content[0] + content2 = result.content[1] + assert isinstance(content1, TextContent) + assert content1.text == "Hello" + assert isinstance(content2, ImageContent) + assert content2.mimeType == "image/png" + assert content2.data == "abc" + + async def test_tool_mixed_list_with_image(self, tmp_path: Path): + """Test that lists containing Image objects and other types are handled correctly""" + # Create a test image + image_path = tmp_path / "test.png" + image_path.write_bytes(b"test image data") + + def mixed_list_fn() -> list: + return [ + "text message", + Image(image_path), + {"key": "value"}, + TextContent(type="text", text="direct content"), + ] + + mcp = FastMCP() + mcp.add_tool(mixed_list_fn) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("mixed_list_fn", {}) + assert len(result.content) == 4 + # Check text conversion + content1 = result.content[0] + assert isinstance(content1, TextContent) + assert content1.text == "text message" + # Check image conversion + content2 = result.content[1] + assert isinstance(content2, ImageContent) + assert content2.mimeType == "image/png" + assert base64.b64decode(content2.data) == b"test image data" + # Check dict conversion + content3 = result.content[2] + assert isinstance(content3, TextContent) + assert '"key": "value"' in content3.text + # Check direct TextContent + content4 = result.content[3] + assert isinstance(content4, TextContent) + assert content4.text == "direct content" + + +class TestServerResources: + async def test_text_resource(self): + mcp = FastMCP() + + def get_text(): + return "Hello, world!" + + resource = FunctionResource( + uri=AnyUrl("resource://test"), name="test", fn=get_text + ) + mcp.add_resource(resource) + + async with client_session(mcp._mcp_server) as client: + result = await client.read_resource(AnyUrl("resource://test")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Hello, world!" + + async def test_binary_resource(self): + mcp = FastMCP() + + def get_binary(): + return b"Binary data" + + resource = FunctionResource( + uri=AnyUrl("resource://binary"), + name="binary", + fn=get_binary, + mime_type="application/octet-stream", + ) + mcp.add_resource(resource) + + async with client_session(mcp._mcp_server) as client: + result = await client.read_resource(AnyUrl("resource://binary")) + assert isinstance(result.contents[0], BlobResourceContents) + assert result.contents[0].blob == base64.b64encode(b"Binary data").decode() + + async def test_file_resource_text(self, tmp_path: Path): + mcp = FastMCP() + + # Create a text file + text_file = tmp_path / "test.txt" + text_file.write_text("Hello from file!") + + resource = FileResource( + uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file + ) + mcp.add_resource(resource) + + async with client_session(mcp._mcp_server) as client: + result = await client.read_resource(AnyUrl("file://test.txt")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Hello from file!" + + async def test_file_resource_binary(self, tmp_path: Path): + mcp = FastMCP() + + # Create a binary file + binary_file = tmp_path / "test.bin" + binary_file.write_bytes(b"Binary file data") + + resource = FileResource( + uri=AnyUrl("file://test.bin"), + name="test.bin", + path=binary_file, + mime_type="application/octet-stream", + ) + mcp.add_resource(resource) + + async with client_session(mcp._mcp_server) as client: + result = await client.read_resource(AnyUrl("file://test.bin")) + assert isinstance(result.contents[0], BlobResourceContents) + assert ( + result.contents[0].blob + == base64.b64encode(b"Binary file data").decode() + ) + + +class TestServerResourceTemplates: + async def test_resource_with_params(self): + """Test that a resource with function parameters raises an error if the URI + parameters don't match""" + mcp = FastMCP() + + with pytest.raises(ValueError, match="Mismatch between URI parameters"): + + @mcp.resource("resource://data") + def get_data_fn(param: str) -> str: + return f"Data: {param}" + + async def test_resource_with_uri_params(self): + """Test that a resource with URI parameters is automatically a template""" + mcp = FastMCP() + + with pytest.raises(ValueError, match="Mismatch between URI parameters"): + + @mcp.resource("resource://{param}") + def get_data() -> str: + return "Data" + + async def test_resource_with_untyped_params(self): + """Test that a resource with untyped parameters raises an error""" + mcp = FastMCP() + + @mcp.resource("resource://{param}") + def get_data(param) -> str: + return "Data" + + async def test_resource_matching_params(self): + """Test that a resource with matching URI and function parameters works""" + mcp = FastMCP() + + @mcp.resource("resource://{name}/data") + def get_data(name: str) -> str: + return f"Data for {name}" + + async with client_session(mcp._mcp_server) as client: + result = await client.read_resource(AnyUrl("resource://test/data")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Data for test" + + async def test_resource_mismatched_params(self): + """Test that mismatched parameters raise an error""" + mcp = FastMCP() + + with pytest.raises(ValueError, match="Mismatch between URI parameters"): + + @mcp.resource("resource://{name}/data") + def get_data(user: str) -> str: + return f"Data for {user}" + + async def test_resource_multiple_params(self): + """Test that multiple parameters work correctly""" + mcp = FastMCP() + + @mcp.resource("resource://{org}/{repo}/data") + def get_data(org: str, repo: str) -> str: + return f"Data for {org}/{repo}" + + async with client_session(mcp._mcp_server) as client: + result = await client.read_resource( + AnyUrl("resource://cursor/fastmcp/data") + ) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Data for cursor/fastmcp" + + async def test_resource_multiple_mismatched_params(self): + """Test that mismatched parameters raise an error""" + mcp = FastMCP() + + with pytest.raises(ValueError, match="Mismatch between URI parameters"): + + @mcp.resource("resource://{org}/{repo}/data") + def get_data_mismatched(org: str, repo_2: str) -> str: + return f"Data for {org}" + + """Test that a resource with no parameters works as a regular resource""" + mcp = FastMCP() + + @mcp.resource("resource://static") + def get_static_data() -> str: + return "Static data" + + async with client_session(mcp._mcp_server) as client: + result = await client.read_resource(AnyUrl("resource://static")) + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Static data" + + async def test_template_to_resource_conversion(self): + """Test that templates are properly converted to resources when accessed""" + mcp = FastMCP() + + @mcp.resource("resource://{name}/data") + def get_data(name: str) -> str: + return f"Data for {name}" + + # Should be registered as a template + assert len(mcp._resource_manager._templates) == 1 + assert len(await mcp.list_resources()) == 0 + + # When accessed, should create a concrete resource + resource = await mcp._resource_manager.get_resource("resource://test/data") + assert isinstance(resource, FunctionResource) + result = await resource.read() + assert result == "Data for test" + + +class TestContextInjection: + """Test context injection in tools.""" + + async def test_context_detection(self): + """Test that context parameters are properly detected.""" + mcp = FastMCP() + + def tool_with_context(x: int, ctx: Context) -> str: + return f"Request {ctx.request_id}: {x}" + + tool = mcp._tool_manager.add_tool(tool_with_context) + assert tool.context_kwarg == "ctx" + + async def test_context_injection(self): + """Test that context is properly injected into tool calls.""" + mcp = FastMCP() + + def tool_with_context(x: int, ctx: Context) -> str: + assert ctx.request_id is not None + return f"Request {ctx.request_id}: {x}" + + mcp.add_tool(tool_with_context) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("tool_with_context", {"x": 42}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert "Request" in content.text + assert "42" in content.text + + async def test_async_context(self): + """Test that context works in async functions.""" + mcp = FastMCP() + + async def async_tool(x: int, ctx: Context) -> str: + assert ctx.request_id is not None + return f"Async request {ctx.request_id}: {x}" + + mcp.add_tool(async_tool) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("async_tool", {"x": 42}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert "Async request" in content.text + assert "42" in content.text + + async def test_context_logging(self): + """Test that context logging methods work.""" + mcp = FastMCP() + + def logging_tool(msg: str, ctx: Context) -> str: + ctx.debug("Debug message") + ctx.info("Info message") + ctx.warning("Warning message") + ctx.error("Error message") + return f"Logged messages for {msg}" + + mcp.add_tool(logging_tool) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("logging_tool", {"msg": "test"}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert "Logged messages for test" in content.text + + async def test_optional_context(self): + """Test that context is optional.""" + mcp = FastMCP() + + def no_context(x: int) -> int: + return x * 2 + + mcp.add_tool(no_context) + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("no_context", {"x": 21}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "42" + + async def test_context_resource_access(self): + """Test that context can access resources.""" + mcp = FastMCP() + + @mcp.resource("test://data") + def test_resource() -> str: + return "resource data" + + @mcp.tool() + async def tool_with_resource(ctx: Context) -> str: + data = await ctx.read_resource("test://data") + return f"Read resource: {data}" + + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("tool_with_resource", {}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert "Read resource: resource data" in content.text + + +class TestServerPrompts: + """Test prompt functionality in FastMCP server.""" + + async def test_prompt_decorator(self): + """Test that the prompt decorator registers prompts correctly.""" + mcp = FastMCP() + + @mcp.prompt() + def fn() -> str: + return "Hello, world!" + + prompts = mcp._prompt_manager.list_prompts() + assert len(prompts) == 1 + assert prompts[0].name == "fn" + # Don't compare functions directly since validate_call wraps them + content = await prompts[0].render() + assert isinstance(content[0].content, TextContent) + assert content[0].content.text == "Hello, world!" + + async def test_prompt_decorator_with_name(self): + """Test prompt decorator with custom name.""" + mcp = FastMCP() + + @mcp.prompt(name="custom_name") + def fn() -> str: + return "Hello, world!" + + prompts = mcp._prompt_manager.list_prompts() + assert len(prompts) == 1 + assert prompts[0].name == "custom_name" + content = await prompts[0].render() + assert isinstance(content[0].content, TextContent) + assert content[0].content.text == "Hello, world!" + + async def test_prompt_decorator_with_description(self): + """Test prompt decorator with custom description.""" + mcp = FastMCP() + + @mcp.prompt(description="A custom description") + def fn() -> str: + return "Hello, world!" + + prompts = mcp._prompt_manager.list_prompts() + assert len(prompts) == 1 + assert prompts[0].description == "A custom description" + content = await prompts[0].render() + assert isinstance(content[0].content, TextContent) + assert content[0].content.text == "Hello, world!" + + def test_prompt_decorator_error(self): + """Test error when decorator is used incorrectly.""" + mcp = FastMCP() + with pytest.raises(TypeError, match="decorator was used incorrectly"): + + @mcp.prompt # type: ignore + def fn() -> str: + return "Hello, world!" + + async def test_list_prompts(self): + """Test listing prompts through MCP protocol.""" + mcp = FastMCP() + + @mcp.prompt() + def fn(name: str, optional: str = "default") -> str: + return f"Hello, {name}!" + + async with client_session(mcp._mcp_server) as client: + result = await client.list_prompts() + assert result.prompts is not None + assert len(result.prompts) == 1 + prompt = result.prompts[0] + assert prompt.name == "fn" + assert prompt.arguments is not None + assert len(prompt.arguments) == 2 + assert prompt.arguments[0].name == "name" + assert prompt.arguments[0].required is True + assert prompt.arguments[1].name == "optional" + assert prompt.arguments[1].required is False + + async def test_get_prompt(self): + """Test getting a prompt through MCP protocol.""" + mcp = FastMCP() + + @mcp.prompt() + def fn(name: str) -> str: + return f"Hello, {name}!" + + async with client_session(mcp._mcp_server) as client: + result = await client.get_prompt("fn", {"name": "World"}) + assert len(result.messages) == 1 + message = result.messages[0] + assert message.role == "user" + content = message.content + assert isinstance(content, TextContent) + assert content.text == "Hello, World!" + + async def test_get_prompt_with_resource(self): + """Test getting a prompt that returns resource content.""" + mcp = FastMCP() + + @mcp.prompt() + def fn() -> Message: + return UserMessage( + content=EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=AnyUrl("file://file.txt"), + text="File contents", + mimeType="text/plain", + ), + ) + ) + + async with client_session(mcp._mcp_server) as client: + result = await client.get_prompt("fn") + assert len(result.messages) == 1 + message = result.messages[0] + assert message.role == "user" + content = message.content + assert isinstance(content, EmbeddedResource) + resource = content.resource + assert isinstance(resource, TextResourceContents) + assert resource.text == "File contents" + assert resource.mimeType == "text/plain" + + async def test_get_unknown_prompt(self): + """Test error when getting unknown prompt.""" + mcp = FastMCP() + async with client_session(mcp._mcp_server) as client: + with pytest.raises(McpError, match="Unknown prompt"): + await client.get_prompt("unknown") + + async def test_get_prompt_missing_args(self): + """Test error when required arguments are missing.""" + mcp = FastMCP() + + @mcp.prompt() + def prompt_fn(name: str) -> str: + return f"Hello, {name}!" + + async with client_session(mcp._mcp_server) as client: + with pytest.raises(McpError, match="Missing required arguments"): + await client.get_prompt("prompt_fn") diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py new file mode 100644 index 0000000..884059a --- /dev/null +++ b/tests/server/fastmcp/test_tool_manager.py @@ -0,0 +1,306 @@ +import logging +from typing import Optional + +import pytest +from pydantic import BaseModel +import json +from mcp.server.fastmcp.exceptions import ToolError +from mcp.server.fastmcp.tools import ToolManager + + +class TestAddTools: + def test_basic_function(self): + """Test registering and running a basic function.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + manager.add_tool(add) + + tool = manager.get_tool("add") + assert tool is not None + assert tool.name == "add" + assert tool.description == "Add two numbers." + assert tool.is_async is False + assert tool.parameters["properties"]["a"]["type"] == "integer" + assert tool.parameters["properties"]["b"]["type"] == "integer" + + async def test_async_function(self): + """Test registering and running an async function.""" + + async def fetch_data(url: str) -> str: + """Fetch data from URL.""" + return f"Data from {url}" + + manager = ToolManager() + manager.add_tool(fetch_data) + + tool = manager.get_tool("fetch_data") + assert tool is not None + assert tool.name == "fetch_data" + assert tool.description == "Fetch data from URL." + assert tool.is_async is True + assert tool.parameters["properties"]["url"]["type"] == "string" + + def test_pydantic_model_function(self): + """Test registering a function that takes a Pydantic model.""" + + class UserInput(BaseModel): + name: str + age: int + + def create_user(user: UserInput, flag: bool) -> dict: + """Create a new user.""" + return {"id": 1, **user.model_dump()} + + manager = ToolManager() + manager.add_tool(create_user) + + tool = manager.get_tool("create_user") + assert tool is not None + assert tool.name == "create_user" + assert tool.description == "Create a new user." + assert tool.is_async is False + assert "name" in tool.parameters["$defs"]["UserInput"]["properties"] + assert "age" in tool.parameters["$defs"]["UserInput"]["properties"] + assert "flag" in tool.parameters["properties"] + + def test_add_invalid_tool(self): + manager = ToolManager() + with pytest.raises(AttributeError): + manager.add_tool(1) # type: ignore + + def test_add_lambda(self): + manager = ToolManager() + tool = manager.add_tool(lambda x: x, name="my_tool") + assert tool.name == "my_tool" + + def test_add_lambda_with_no_name(self): + manager = ToolManager() + with pytest.raises( + ValueError, match="You must provide a name for lambda functions" + ): + manager.add_tool(lambda x: x) + + def test_warn_on_duplicate_tools(self, caplog): + """Test warning on duplicate tools.""" + + def f(x: int) -> int: + return x + + manager = ToolManager() + manager.add_tool(f) + with caplog.at_level(logging.WARNING): + manager.add_tool(f) + assert "Tool already exists: f" in caplog.text + + def test_disable_warn_on_duplicate_tools(self, caplog): + """Test disabling warning on duplicate tools.""" + + def f(x: int) -> int: + return x + + manager = ToolManager() + manager.add_tool(f) + manager.warn_on_duplicate_tools = False + with caplog.at_level(logging.WARNING): + manager.add_tool(f) + assert "Tool already exists: f" not in caplog.text + + +class TestCallTools: + async def test_call_tool(self): + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + manager.add_tool(add) + result = await manager.call_tool("add", {"a": 1, "b": 2}) + assert result == 3 + + async def test_call_async_tool(self): + async def double(n: int) -> int: + """Double a number.""" + return n * 2 + + manager = ToolManager() + manager.add_tool(double) + result = await manager.call_tool("double", {"n": 5}) + assert result == 10 + + async def test_call_tool_with_default_args(self): + def add(a: int, b: int = 1) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + manager.add_tool(add) + result = await manager.call_tool("add", {"a": 1}) + assert result == 2 + + async def test_call_tool_with_missing_args(self): + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + manager.add_tool(add) + with pytest.raises(ToolError): + await manager.call_tool("add", {"a": 1}) + + async def test_call_unknown_tool(self): + manager = ToolManager() + with pytest.raises(ToolError): + await manager.call_tool("unknown", {"a": 1}) + + async def test_call_tool_with_list_int_input(self): + def sum_vals(vals: list[int]) -> int: + return sum(vals) + + manager = ToolManager() + manager.add_tool(sum_vals) + # Try both with plain list and with JSON list + result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}) + assert result == 6 + result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}) + assert result == 6 + + async def test_call_tool_with_list_str_or_str_input(self): + def concat_strs(vals: list[str] | str) -> str: + return vals if isinstance(vals, str) else "".join(vals) + + manager = ToolManager() + manager.add_tool(concat_strs) + # Try both with plain python object and with JSON list + result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}) + assert result == "abc" + result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}) + assert result == "abc" + result = await manager.call_tool("concat_strs", {"vals": "a"}) + assert result == "a" + result = await manager.call_tool("concat_strs", {"vals": '"a"'}) + assert result == '"a"' + + async def test_call_tool_with_complex_model(self): + from mcp.server.fastmcp import Context + + class MyShrimpTank(BaseModel): + class Shrimp(BaseModel): + name: str + + shrimp: list[Shrimp] + x: None + + def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: + return [x.name for x in tank.shrimp] + + manager = ToolManager() + manager.add_tool(name_shrimp) + result = await manager.call_tool( + "name_shrimp", + {"tank": {"x": None, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}}, + ) + assert result == ["rex", "gertrude"] + result = await manager.call_tool( + "name_shrimp", + {"tank": '{"x": null, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}'}, + ) + assert result == ["rex", "gertrude"] + + +class TestToolSchema: + async def test_context_arg_excluded_from_schema(self): + from mcp.server.fastmcp import Context + + def something(a: int, ctx: Context) -> int: + return a + + manager = ToolManager() + tool = manager.add_tool(something) + assert "ctx" not in json.dumps(tool.parameters) + assert "Context" not in json.dumps(tool.parameters) + assert "ctx" not in tool.fn_metadata.arg_model.model_fields + + +class TestContextHandling: + """Test context handling in the tool manager.""" + + def test_context_parameter_detection(self): + """Test that context parameters are properly detected in Tool.from_function().""" + from mcp.server.fastmcp import Context + + def tool_with_context(x: int, ctx: Context) -> str: + return str(x) + + manager = ToolManager() + tool = manager.add_tool(tool_with_context) + assert tool.context_kwarg == "ctx" + + def tool_without_context(x: int) -> str: + return str(x) + + tool = manager.add_tool(tool_without_context) + assert tool.context_kwarg is None + + async def test_context_injection(self): + """Test that context is properly injected during tool execution.""" + from mcp.server.fastmcp import Context, FastMCP + + def tool_with_context(x: int, ctx: Context) -> str: + assert isinstance(ctx, Context) + return str(x) + + manager = ToolManager() + manager.add_tool(tool_with_context) + + mcp = FastMCP() + ctx = mcp.get_context() + result = await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) + assert result == "42" + + async def test_context_injection_async(self): + """Test that context is properly injected in async tools.""" + from mcp.server.fastmcp import Context, FastMCP + + async def async_tool(x: int, ctx: Context) -> str: + assert isinstance(ctx, Context) + return str(x) + + manager = ToolManager() + manager.add_tool(async_tool) + + mcp = FastMCP() + ctx = mcp.get_context() + result = await manager.call_tool("async_tool", {"x": 42}, context=ctx) + assert result == "42" + + async def test_context_optional(self): + """Test that context is optional when calling tools.""" + from mcp.server.fastmcp import Context + + def tool_with_context(x: int, ctx: Optional[Context] = None) -> str: + return str(x) + + manager = ToolManager() + manager.add_tool(tool_with_context) + # Should not raise an error when context is not provided + result = await manager.call_tool("tool_with_context", {"x": 42}) + assert result == "42" + + async def test_context_error_handling(self): + """Test error handling when context injection fails.""" + from mcp.server.fastmcp import Context, FastMCP + + def tool_with_context(x: int, ctx: Context) -> str: + raise ValueError("Test error") + + manager = ToolManager() + manager.add_tool(tool_with_context) + + mcp = FastMCP() + ctx = mcp.get_context() + with pytest.raises(ToolError, match="Error executing tool tool_with_context"): + await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a78ca90..ead18f7 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -2,7 +2,8 @@ import anyio import pytest from mcp.client.session import ClientSession -from mcp.server import NotificationOptions, Server +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.types import ( diff --git a/uv.lock b/uv.lock index db4dbc7..18009ba 100644 --- a/uv.lock +++ b/uv.lock @@ -38,20 +38,20 @@ wheels = [ [[package]] name = "attrs" -version = "24.2.0" +version = "24.3.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/0f/aafca9af9315aee06a89ffde799a10a582fe8de76c563ee80bbcdc08b3fb/attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346", size = 792678 } +sdist = { url = "https://files.pythonhosted.org/packages/48/c8/6260f8ccc11f0917360fc0da435c5c9c7504e3db174d5a12a1494887b045/attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff", size = 805984 } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/21/5b6702a7f963e95456c0de2d495f67bf5fd62840ac655dc451586d23d39a/attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2", size = 63001 }, + { url = "https://files.pythonhosted.org/packages/89/aa/ab0f7891a01eeb2d2e338ae8fecbe57fcebea1a24dbb64d45801bfab481d/attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308", size = 63397 }, ] [[package]] name = "certifi" -version = "2024.8.30" +version = "2024.12.14" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/ee/9b19140fe824b367c04c5e1b369942dd754c4c5462d5674002f75c4dedc1/certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9", size = 168507 } +sdist = { url = "https://files.pythonhosted.org/packages/0f/bd/1d41ee578ce09523c81a15426705dd20969f5abf006d1afe8aeff0dd776a/certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db", size = 166010 } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/90/3c9ff0512038035f59d279fddeb79f5f1eccd8859f06d6163c58798b9487/certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", size = 167321 }, + { url = "https://files.pythonhosted.org/packages/a5/32/8f6669fc4798494966bf446c8c4a162e0b5d893dff088afddf76414f70e1/certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56", size = 164927 }, ] [[package]] @@ -103,6 +103,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, ] +[[package]] +name = "execnet" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/ff/b4c0dc78fbe20c3e59c0c7334de0c27eb4001a2b2017999af398bf730817/execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3", size = 166524 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/09/2aea36ff60d16dd8879bdb2f5b3ee0ba8d08cbbdcdfe870e695ce3784385/execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", size = 40612 }, +] + [[package]] name = "h11" version = "0.14.0" @@ -168,6 +177,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, ] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 }, +] + [[package]] name = "mcp" version = "1.1.2.dev0" @@ -177,14 +198,23 @@ dependencies = [ { name = "httpx" }, { name = "httpx-sse" }, { name = "pydantic" }, + { name = "pydantic-settings" }, { name = "sse-starlette" }, { name = "starlette" }, ] +[package.optional-dependencies] +rich = [ + { name = "rich" }, +] + [package.dev-dependencies] dev = [ { name = "pyright" }, { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-flakefinder" }, + { name = "pytest-xdist" }, { name = "ruff" }, { name = "trio" }, ] @@ -194,7 +224,9 @@ requires-dist = [ { name = "anyio", specifier = ">=4.5" }, { name = "httpx", specifier = ">=0.27" }, { name = "httpx-sse", specifier = ">=0.4" }, - { name = "pydantic", specifier = ">=2.7.2" }, + { name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, + { name = "pydantic-settings", specifier = ">=2.6.1" }, + { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, { name = "sse-starlette", specifier = ">=1.6.1" }, { name = "starlette", specifier = ">=0.27" }, ] @@ -203,6 +235,9 @@ requires-dist = [ dev = [ { name = "pyright", specifier = ">=1.1.378" }, { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, + { name = "pytest-flakefinder", specifier = ">=1.1.0" }, + { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.6.9" }, { name = "trio", specifier = ">=0.26.2" }, ] @@ -306,6 +341,15 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -425,6 +469,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/d1/1c18f8e215930665e65597dd677937595355057f631bf4b9110aa6f88f79/pydantic_core-2.18.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c0037a92cf0c580ed14e10953cdd26528e8796307bb8bb312dc65f71547df04d", size = 1898163 }, ] +[[package]] +name = "pydantic-settings" +version = "2.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b5/d4/9dfbe238f45ad8b168f5c96ee49a3df0598ce18a0795a983b419949ce65b/pydantic_settings-2.6.1.tar.gz", hash = "sha256:e0f92546d8a9923cb8941689abf85d6601a8c19a23e97a34b2964a2e3f813ca0", size = 75646 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/f9/ff95fd7d760af42f647ea87f9b8a383d891cdb5e5dbd4613edaeb094252a/pydantic_settings-2.6.1-py3-none-any.whl", hash = "sha256:7fb0637c786a558d3103436278a7c4f1cfd29ba8973238a50c5bb9a55387da87", size = 28595 }, +] + +[[package]] +name = "pygments" +version = "2.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/62/8336eff65bcbc8e4cb5d05b55faf041285951b6e80f33e2bff2024788f31/pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199", size = 4891905 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 }, +] + [[package]] name = "pyright" version = "1.1.378" @@ -454,6 +520,66 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/6d/c6cf50ce320cf8611df7a1254d86233b3df7cc07f9b5f5cbcb82e08aa534/pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276", size = 49855 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/31/6607dab48616902f76885dfcf62c08d929796fc3b2d2318faf9fd54dbed9/pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b", size = 18024 }, +] + +[[package]] +name = "pytest-flakefinder" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ec/53/69c56a93ea057895b5761c5318455804873a6cd9d796d7c55d41c2358125/pytest-flakefinder-1.1.0.tar.gz", hash = "sha256:e2412a1920bdb8e7908783b20b3d57e9dad590cc39a93e8596ffdd493b403e0e", size = 6795 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/8b/06787150d0fd0cbd3a8054262b56f91631c7778c1bc91bf4637e47f909ad/pytest_flakefinder-1.1.0-py2.py3-none-any.whl", hash = "sha256:741e0e8eea427052f5b8c89c2b3c3019a50c39a59ce4df6a305a2c2d9ba2bd13", size = 4644 }, +] + +[[package]] +name = "pytest-xdist" +version = "3.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/c4/3c310a19bc1f1e9ef50075582652673ef2bfc8cd62afef9585683821902f/pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d", size = 84060 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/82/1d96bf03ee4c0fdc3c0cbe61470070e659ca78dc0086fb88b66c185e2449/pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7", size = 46108 }, +] + +[[package]] +name = "python-dotenv" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, +] + +[[package]] +name = "rich" +version = "13.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 }, +] + [[package]] name = "ruff" version = "0.6.9"