style: Fix imports and line length formatting

This commit is contained in:
David Soria Parra
2024-12-19 22:33:40 +00:00
parent 7bbf71e29a
commit a79f51f55f
37 changed files with 242 additions and 135 deletions

View File

@@ -4,8 +4,10 @@ FastMCP Complex inputs Example
Demonstrates validation via pydantic with complex models. Demonstrates validation via pydantic with complex models.
""" """
from pydantic import BaseModel, Field
from typing import Annotated from typing import Annotated
from pydantic import BaseModel, Field
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
mcp = FastMCP("Shrimp Tank") mcp = FastMCP("Shrimp Tank")

View File

@@ -6,7 +6,8 @@
""" """
Recursive memory system inspired by the human brain's clustering of memories. Recursive memory system inspired by the human brain's clustering of memories.
Uses OpenAI's 'text-embedding-3-small' model and pgvector for efficient similarity search. Uses OpenAI's 'text-embedding-3-small' model and pgvector for efficient
similarity search.
""" """
import asyncio import asyncio
@@ -111,7 +112,8 @@ class MemoryNode(BaseModel):
if self.id is None: if self.id is None:
result = await conn.fetchrow( result = await conn.fetchrow(
""" """
INSERT INTO memories (content, summary, importance, access_count, timestamp, embedding) INSERT INTO memories (content, summary, importance, access_count,
timestamp, embedding)
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id RETURNING id
""", """,
@@ -336,7 +338,8 @@ async def initialize_database():
timestamp DOUBLE PRECISION NOT NULL, timestamp DOUBLE PRECISION NOT NULL,
embedding vector(1536) NOT NULL embedding vector(1536) NOT NULL
); );
CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories USING hnsw (embedding vector_l2_ops); CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories
USING hnsw (embedding vector_l2_ops);
""") """)
finally: finally:
await pool.close() await pool.close()

View File

@@ -1,6 +1,5 @@
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
# Create an MCP server # Create an MCP server
mcp = FastMCP("Demo") mcp = FastMCP("Demo")

View File

@@ -5,10 +5,10 @@ Give Claude a tool to capture and view screenshots.
""" """
import io import io
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.utilities.types import Image from mcp.server.fastmcp.utilities.types import Image
# Create server # Create server
mcp = FastMCP("Screenshot Demo", dependencies=["pyautogui", "Pillow"]) mcp = FastMCP("Screenshot Demo", dependencies=["pyautogui", "Pillow"])

View File

@@ -4,7 +4,6 @@ FastMCP Echo Server
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
# Create server # Create server
mcp = FastMCP("Echo Server") mcp = FastMCP("Echo Server")

View File

@@ -19,6 +19,7 @@ Visit https://surgemsg.com/ and click "Get Started" to obtain these values.
""" """
from typing import Annotated from typing import Annotated
import httpx import httpx
from pydantic import BeforeValidator from pydantic import BeforeValidator
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict

View File

@@ -1,8 +1,8 @@
import anyio import anyio
import click import click
import mcp.types as types import mcp.types as types
from pydantic import AnyUrl
from mcp.server.lowlevel import Server from mcp.server.lowlevel import Server
from pydantic import AnyUrl
SAMPLE_RESOURCES = { SAMPLE_RESOURCES = {
"greeting": "Hello! This is a sample text resource.", "greeting": "Hello! This is a sample text resource.",

View File

@@ -47,7 +47,6 @@ dev-dependencies = [
"trio>=0.26.2", "trio>=0.26.2",
"pytest-flakefinder>=1.1.0", "pytest-flakefinder>=1.1.0",
"pytest-xdist>=3.6.1", "pytest-xdist>=3.6.1",
"pytest-asyncio>=0.24.0",
] ]
[build-system] [build-system]

View File

@@ -2,6 +2,5 @@
from .cli import app from .cli import app
if __name__ == "__main__": if __name__ == "__main__":
app() app()

View File

@@ -48,8 +48,8 @@ def update_claude_config(
config_dir = get_claude_config_path() config_dir = get_claude_config_path()
if not config_dir: if not config_dir:
raise RuntimeError( raise RuntimeError(
"Claude Desktop config directory not found. Please ensure Claude Desktop " "Claude Desktop config directory not found. Please ensure Claude Desktop"
"is installed and has been run at least once to initialize its configuration." " is installed and has been run at least once to initialize its config."
) )
config_file = config_dir / "claude_desktop_config.json" config_file = config_dir / "claude_desktop_config.json"

View File

@@ -295,7 +295,8 @@ def run(
"""Run a MCP server. """Run a MCP server.
The server can be specified in two ways: The server can be specified in two ways:
1. Module approach: server.py - runs the module directly, expecting a server.run() call 1. Module approach: server.py - runs the module directly, expecting a server.run()
call
2. Import approach: server.py:app - imports and runs the specified server object 2. Import approach: server.py:app - imports and runs the specified server object
Note: This command runs the server directly. You are responsible for ensuring Note: This command runs the server directly. You are responsible for ensuring
@@ -346,7 +347,8 @@ def install(
typer.Option( typer.Option(
"--name", "--name",
"-n", "-n",
help="Custom name for the server (defaults to server's name attribute or file name)", help="Custom name for the server (defaults to server's name attribute or"
" file name)",
), ),
] = None, ] = None,
with_editable: Annotated[ with_editable: Annotated[
@@ -410,7 +412,8 @@ def install(
logger.error("Claude app not found") logger.error("Claude app not found")
sys.exit(1) sys.exit(1)
# Try to import server to get its name, but fall back to file name if dependencies missing # Try to import server to get its name, but fall back to file name if dependencies
# missing
name = server_name name = server_name
server = None server = None
if not name: if not name:
@@ -419,7 +422,8 @@ def install(
name = server.name name = server.name
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
logger.debug( logger.debug(
"Could not import server (likely missing dependencies), using file name", "Could not import server (likely missing dependencies), using file"
" name",
extra={"error": str(e)}, extra={"error": str(e)},
) )
name = file.stem name = file.stem

View File

@@ -1,4 +1,4 @@
from .lowlevel import Server, NotificationOptions
from .fastmcp import FastMCP from .fastmcp import FastMCP
from .lowlevel import NotificationOptions, Server
__all__ = ["Server", "FastMCP", "NotificationOptions"] __all__ = ["Server", "FastMCP", "NotificationOptions"]

View File

@@ -1,7 +1,8 @@
"""FastMCP - A more ergonomic interface for MCP servers.""" """FastMCP - A more ergonomic interface for MCP servers."""
from importlib.metadata import version from importlib.metadata import version
from .server import FastMCP, Context
from .server import Context, FastMCP
from .utilities.types import Image from .utilities.types import Image
__version__ = version("mcp") __version__ = version("mcp")

View File

@@ -1,13 +1,14 @@
"""Base classes for FastMCP prompts.""" """Base classes for FastMCP prompts."""
import json
from typing import Any, Literal, Sequence, Awaitable
import inspect import inspect
import json
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Awaitable, Literal, Sequence
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.types import TextContent, ImageContent, EmbeddedResource
import pydantic_core import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.types import EmbeddedResource, ImageContent, TextContent
CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource

View File

@@ -1,14 +1,14 @@
from .base import Resource from .base import Resource
from .types import (
TextResource,
BinaryResource,
FunctionResource,
FileResource,
HttpResource,
DirectoryResource,
)
from .templates import ResourceTemplate
from .resource_manager import ResourceManager from .resource_manager import ResourceManager
from .templates import ResourceTemplate
from .types import (
BinaryResource,
DirectoryResource,
FileResource,
FunctionResource,
HttpResource,
TextResource,
)
__all__ = [ __all__ = [
"Resource", "Resource",

View File

@@ -1,7 +1,6 @@
"""Resource manager functionality.""" """Resource manager functionality."""
from typing import Callable from typing import Callable
from collections.abc import Iterable
from pydantic import AnyUrl from pydantic import AnyUrl

View File

@@ -1,11 +1,11 @@
"""Concrete resource implementations.""" """Concrete resource implementations."""
import anyio
import json import json
from pathlib import Path
from typing import Any, Callable
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path
from typing import Any
import anyio
import httpx import httpx
import pydantic.json import pydantic.json
import pydantic_core import pydantic_core

View File

@@ -1,17 +1,25 @@
"""FastMCP - A more ergonomic interface for MCP servers.""" """FastMCP - A more ergonomic interface for MCP servers."""
import anyio
import functools import functools
import inspect import inspect
import json import json
import re import re
from itertools import chain from itertools import chain
from typing import Any, Callable, Literal, Sequence from typing import Any, Callable, Literal, Sequence
from collections.abc import Iterable
import anyio
import pydantic_core import pydantic_core
from pydantic import Field
import uvicorn import uvicorn
from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from mcp.server.fastmcp.exceptions import ResourceError
from mcp.server.fastmcp.prompts import Prompt, PromptManager
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.fastmcp.utilities.types import Image
from mcp.server.lowlevel import Server as MCPServer from mcp.server.lowlevel import Server as MCPServer
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server from mcp.server.stdio import stdio_server
@@ -24,6 +32,8 @@ from mcp.types import (
) )
from mcp.types import ( from mcp.types import (
Prompt as MCPPrompt, Prompt as MCPPrompt,
)
from mcp.types import (
PromptArgument as MCPPromptArgument, PromptArgument as MCPPromptArgument,
) )
from mcp.types import ( from mcp.types import (
@@ -35,16 +45,6 @@ from mcp.types import (
from mcp.types import ( from mcp.types import (
Tool as MCPTool, Tool as MCPTool,
) )
from pydantic import BaseModel
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from mcp.server.fastmcp.exceptions import ResourceError
from mcp.server.fastmcp.prompts import Prompt, PromptManager
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.fastmcp.utilities.types import Image
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -226,8 +226,9 @@ class FastMCP:
def tool(self, name: str | None = None, description: str | None = None) -> Callable: def tool(self, name: str | None = None, description: str | None = None) -> Callable:
"""Decorator to register a tool. """Decorator to register a tool.
Tools can optionally request a Context object by adding a parameter with the Context type annotation. Tools can optionally request a Context object by adding a parameter with the
The context provides access to MCP capabilities like logging, progress reporting, and resource access. Context type annotation. The context provides access to MCP capabilities like
logging, progress reporting, and resource access.
Args: Args:
name: Optional name for the tool (defaults to function name) name: Optional name for the tool (defaults to function name)

View File

@@ -1,12 +1,12 @@
import mcp.server.fastmcp
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import func_metadata, FuncMetadata
from pydantic import BaseModel, Field
import inspect import inspect
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
from pydantic import BaseModel, Field
import mcp.server.fastmcp
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
if TYPE_CHECKING: if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context from mcp.server.fastmcp.server import Context
@@ -19,7 +19,8 @@ class Tool(BaseModel):
description: str = Field(description="Description of what the tool does") description: str = Field(description="Description of what the tool does")
parameters: dict = Field(description="JSON schema for tool parameters") parameters: dict = Field(description="JSON schema for tool parameters")
fn_metadata: FuncMetadata = Field( fn_metadata: FuncMetadata = Field(
description="Metadata about the function including a pydantic model for tool arguments" description="Metadata about the function including a pydantic model for tool"
" arguments"
) )
is_async: bool = Field(description="Whether the tool is async") is_async: bool = Field(description="Whether the tool is async")
context_kwarg: str | None = Field( context_kwarg: str | None = Field(

View File

@@ -1,9 +1,8 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.tools.base import Tool
from typing import Any, Callable, TYPE_CHECKING
from collections.abc import Callable
from mcp.server.fastmcp.utilities.logging import get_logger from mcp.server.fastmcp.utilities.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@@ -1,21 +1,19 @@
import inspect import inspect
from collections.abc import Callable, Sequence, Awaitable import json
from collections.abc import Awaitable, Callable, Sequence
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
ForwardRef, ForwardRef,
) )
from pydantic import Field
from mcp.server.fastmcp.exceptions import InvalidSignature
from pydantic._internal._typing_extra import eval_type_backport
import json
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic import ConfigDict, create_model
from pydantic import WithJsonSchema
from pydantic_core import PydanticUndefined
from mcp.server.fastmcp.utilities.logging import get_logger
from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model
from pydantic._internal._typing_extra import eval_type_backport
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from mcp.server.fastmcp.exceptions import InvalidSignature
from mcp.server.fastmcp.utilities.logging import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -105,7 +103,8 @@ class FuncMetadata(BaseModel):
def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata: def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata:
"""Given a function, return metadata including a pydantic model representing its signature. """Given a function, return metadata including a pydantic model representing its
signature.
The use case for this is The use case for this is
``` ```
@@ -114,7 +113,8 @@ def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadat
return func(**validated_args.model_dump_one_level()) return func(**validated_args.model_dump_one_level())
``` ```
**critically** it also provides pre-parse helper to attempt to parse things from JSON. **critically** it also provides pre-parse helper to attempt to parse things from
JSON.
Args: Args:
func: The function to convert to a pydantic model func: The function to convert to a pydantic model
@@ -130,7 +130,7 @@ def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadat
for param in params.values(): for param in params.values():
if param.name.startswith("_"): if param.name.startswith("_"):
raise InvalidSignature( raise InvalidSignature(
f"Parameter {param.name} of {func.__name__} may not start with an underscore" f"Parameter {param.name} of {func.__name__} cannot start with '_'"
) )
if param.name in skip_names: if param.name in skip_names:
continue continue

View File

@@ -3,6 +3,7 @@
import logging import logging
from typing import Literal from typing import Literal
def get_logger(name: str) -> logging.Logger: def get_logger(name: str) -> logging.Logger:
"""Get a logger nested under MCPnamespace. """Get a logger nested under MCPnamespace.
@@ -27,6 +28,7 @@ def configure_logging(
try: try:
from rich.console import Console from rich.console import Console
from rich.logging import RichHandler from rich.logging import RichHandler
handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True)) handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True))
except ImportError: except ImportError:
pass pass

View File

@@ -1,3 +1,3 @@
from .server import Server, NotificationOptions from .server import NotificationOptions, Server
__all__ = ["Server", "NotificationOptions"] __all__ = ["Server", "NotificationOptions"]

View File

@@ -1,16 +1,18 @@
from pydantic import FileUrl
import pytest import pytest
from pydantic import FileUrl
from mcp.server.fastmcp.prompts.base import ( from mcp.server.fastmcp.prompts.base import (
Prompt,
UserMessage,
TextContent,
AssistantMessage, AssistantMessage,
Message, Message,
Prompt,
TextContent,
UserMessage,
) )
from mcp.types import EmbeddedResource, TextResourceContents from mcp.types import EmbeddedResource, TextResourceContents
class TestRenderPrompt: class TestRenderPrompt:
@pytest.mark.anyio
async def test_basic_fn(self): async def test_basic_fn(self):
def fn() -> str: def fn() -> str:
return "Hello, world!" return "Hello, world!"
@@ -20,6 +22,7 @@ class TestRenderPrompt:
UserMessage(content=TextContent(type="text", text="Hello, world!")) UserMessage(content=TextContent(type="text", text="Hello, world!"))
] ]
@pytest.mark.anyio
async def test_async_fn(self): async def test_async_fn(self):
async def fn() -> str: async def fn() -> str:
return "Hello, world!" return "Hello, world!"
@@ -29,6 +32,7 @@ class TestRenderPrompt:
UserMessage(content=TextContent(type="text", text="Hello, world!")) UserMessage(content=TextContent(type="text", text="Hello, world!"))
] ]
@pytest.mark.anyio
async def test_fn_with_args(self): async def test_fn_with_args(self):
async def fn(name: str, age: int = 30) -> str: async def fn(name: str, age: int = 30) -> str:
return f"Hello, {name}! You're {age} years old." return f"Hello, {name}! You're {age} years old."
@@ -42,6 +46,7 @@ class TestRenderPrompt:
) )
] ]
@pytest.mark.anyio
async def test_fn_with_invalid_kwargs(self): async def test_fn_with_invalid_kwargs(self):
async def fn(name: str, age: int = 30) -> str: async def fn(name: str, age: int = 30) -> str:
return f"Hello, {name}! You're {age} years old." return f"Hello, {name}! You're {age} years old."
@@ -50,6 +55,7 @@ class TestRenderPrompt:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await prompt.render(arguments=dict(age=40)) await prompt.render(arguments=dict(age=40))
@pytest.mark.anyio
async def test_fn_returns_message(self): async def test_fn_returns_message(self):
async def fn() -> UserMessage: async def fn() -> UserMessage:
return UserMessage(content="Hello, world!") return UserMessage(content="Hello, world!")
@@ -59,6 +65,7 @@ class TestRenderPrompt:
UserMessage(content=TextContent(type="text", text="Hello, world!")) UserMessage(content=TextContent(type="text", text="Hello, world!"))
] ]
@pytest.mark.anyio
async def test_fn_returns_assistant_message(self): async def test_fn_returns_assistant_message(self):
async def fn() -> AssistantMessage: async def fn() -> AssistantMessage:
return AssistantMessage( return AssistantMessage(
@@ -70,6 +77,7 @@ class TestRenderPrompt:
AssistantMessage(content=TextContent(type="text", text="Hello, world!")) AssistantMessage(content=TextContent(type="text", text="Hello, world!"))
] ]
@pytest.mark.anyio
async def test_fn_returns_multiple_messages(self): async def test_fn_returns_multiple_messages(self):
expected = [ expected = [
UserMessage("Hello, world!"), UserMessage("Hello, world!"),
@@ -83,6 +91,7 @@ class TestRenderPrompt:
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
assert await prompt.render() == expected assert await prompt.render() == expected
@pytest.mark.anyio
async def test_fn_returns_list_of_strings(self): async def test_fn_returns_list_of_strings(self):
expected = [ expected = [
"Hello, world!", "Hello, world!",
@@ -95,6 +104,7 @@ class TestRenderPrompt:
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
assert await prompt.render() == [UserMessage(t) for t in expected] assert await prompt.render() == [UserMessage(t) for t in expected]
@pytest.mark.anyio
async def test_fn_returns_resource_content(self): async def test_fn_returns_resource_content(self):
"""Test returning a message with resource content.""" """Test returning a message with resource content."""
@@ -124,6 +134,7 @@ class TestRenderPrompt:
) )
] ]
@pytest.mark.anyio
async def test_fn_returns_mixed_content(self): async def test_fn_returns_mixed_content(self):
"""Test returning messages with mixed content types.""" """Test returning messages with mixed content types."""
@@ -163,6 +174,7 @@ class TestRenderPrompt:
), ),
] ]
@pytest.mark.anyio
async def test_fn_returns_dict_with_resource(self): async def test_fn_returns_dict_with_resource(self):
"""Test returning a dict with resource content.""" """Test returning a dict with resource content."""

View File

@@ -1,5 +1,6 @@
import pytest import pytest
from mcp.server.fastmcp.prompts.base import UserMessage, TextContent, Prompt
from mcp.server.fastmcp.prompts.base import Prompt, TextContent, UserMessage
from mcp.server.fastmcp.prompts.manager import PromptManager from mcp.server.fastmcp.prompts.manager import PromptManager
@@ -60,6 +61,7 @@ class TestPromptManager:
assert len(prompts) == 2 assert len(prompts) == 2
assert prompts == [prompt1, prompt2] assert prompts == [prompt1, prompt2]
@pytest.mark.anyio
async def test_render_prompt(self): async def test_render_prompt(self):
"""Test rendering a prompt.""" """Test rendering a prompt."""
@@ -74,6 +76,7 @@ class TestPromptManager:
UserMessage(content=TextContent(type="text", text="Hello, world!")) UserMessage(content=TextContent(type="text", text="Hello, world!"))
] ]
@pytest.mark.anyio
async def test_render_prompt_with_args(self): async def test_render_prompt_with_args(self):
"""Test rendering a prompt with arguments.""" """Test rendering a prompt with arguments."""
@@ -88,12 +91,14 @@ class TestPromptManager:
UserMessage(content=TextContent(type="text", text="Hello, World!")) UserMessage(content=TextContent(type="text", text="Hello, World!"))
] ]
@pytest.mark.anyio
async def test_render_unknown_prompt(self): async def test_render_unknown_prompt(self):
"""Test rendering a non-existent prompt.""" """Test rendering a non-existent prompt."""
manager = PromptManager() manager = PromptManager()
with pytest.raises(ValueError, match="Unknown prompt: unknown"): with pytest.raises(ValueError, match="Unknown prompt: unknown"):
await manager.render_prompt("unknown") await manager.render_prompt("unknown")
@pytest.mark.anyio
async def test_render_prompt_with_missing_args(self): async def test_render_prompt_with_missing_args(self):
"""Test rendering a prompt with missing required arguments.""" """Test rendering a prompt with missing required arguments."""

View File

@@ -1,8 +1,8 @@
import os import os
import pytest
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import pytest
from pydantic import FileUrl from pydantic import FileUrl
from mcp.server.fastmcp.resources import FileResource from mcp.server.fastmcp.resources import FileResource
@@ -53,6 +53,7 @@ class TestFileResource:
assert isinstance(resource.path, Path) assert isinstance(resource.path, Path)
assert resource.path.is_absolute() assert resource.path.is_absolute()
@pytest.mark.anyio
async def test_read_text_file(self, temp_file: Path): async def test_read_text_file(self, temp_file: Path):
"""Test reading a text file.""" """Test reading a text file."""
resource = FileResource( resource = FileResource(
@@ -64,6 +65,7 @@ class TestFileResource:
assert content == "test content" assert content == "test content"
assert resource.mime_type == "text/plain" assert resource.mime_type == "text/plain"
@pytest.mark.anyio
async def test_read_binary_file(self, temp_file: Path): async def test_read_binary_file(self, temp_file: Path):
"""Test reading a file as binary.""" """Test reading a file as binary."""
resource = FileResource( resource = FileResource(
@@ -85,6 +87,7 @@ class TestFileResource:
path=Path("test.txt"), path=Path("test.txt"),
) )
@pytest.mark.anyio
async def test_missing_file_error(self, temp_file: Path): async def test_missing_file_error(self, temp_file: Path):
"""Test error when file doesn't exist.""" """Test error when file doesn't exist."""
# Create path to non-existent file # Create path to non-existent file
@@ -100,6 +103,7 @@ class TestFileResource:
@pytest.mark.skipif( @pytest.mark.skipif(
os.name == "nt", reason="File permissions behave differently on Windows" os.name == "nt", reason="File permissions behave differently on Windows"
) )
@pytest.mark.anyio
async def test_permission_error(self, temp_file: Path): async def test_permission_error(self, temp_file: Path):
"""Test reading a file without permissions.""" """Test reading a file without permissions."""
temp_file.chmod(0o000) # Remove all permissions temp_file.chmod(0o000) # Remove all permissions

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, AnyUrl
import pytest import pytest
from pydantic import AnyUrl, BaseModel
from mcp.server.fastmcp.resources import FunctionResource from mcp.server.fastmcp.resources import FunctionResource
@@ -24,6 +25,7 @@ class TestFunctionResource:
assert resource.mime_type == "text/plain" # default assert resource.mime_type == "text/plain" # default
assert resource.fn == my_func assert resource.fn == my_func
@pytest.mark.anyio
async def test_read_text(self): async def test_read_text(self):
"""Test reading text from a FunctionResource.""" """Test reading text from a FunctionResource."""
@@ -39,6 +41,7 @@ class TestFunctionResource:
assert content == "Hello, world!" assert content == "Hello, world!"
assert resource.mime_type == "text/plain" assert resource.mime_type == "text/plain"
@pytest.mark.anyio
async def test_read_binary(self): async def test_read_binary(self):
"""Test reading binary data from a FunctionResource.""" """Test reading binary data from a FunctionResource."""
@@ -53,6 +56,7 @@ class TestFunctionResource:
content = await resource.read() content = await resource.read()
assert content == b"Hello, world!" assert content == b"Hello, world!"
@pytest.mark.anyio
async def test_json_conversion(self): async def test_json_conversion(self):
"""Test automatic JSON conversion of non-string results.""" """Test automatic JSON conversion of non-string results."""
@@ -68,6 +72,7 @@ class TestFunctionResource:
assert isinstance(content, str) assert isinstance(content, str)
assert '"key": "value"' in content assert '"key": "value"' in content
@pytest.mark.anyio
async def test_error_handling(self): async def test_error_handling(self):
"""Test error handling in FunctionResource.""" """Test error handling in FunctionResource."""
@@ -82,6 +87,7 @@ class TestFunctionResource:
with pytest.raises(ValueError, match="Error reading resource function://test"): with pytest.raises(ValueError, match="Error reading resource function://test"):
await resource.read() await resource.read()
@pytest.mark.anyio
async def test_basemodel_conversion(self): async def test_basemodel_conversion(self):
"""Test handling of BaseModel types.""" """Test handling of BaseModel types."""
@@ -96,6 +102,7 @@ class TestFunctionResource:
content = await resource.read() content = await resource.read()
assert content == '{"name": "test"}' assert content == '{"name": "test"}'
@pytest.mark.anyio
async def test_custom_type_conversion(self): async def test_custom_type_conversion(self):
"""Test handling of custom types.""" """Test handling of custom types."""

View File

@@ -1,6 +1,7 @@
import pytest
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import pytest
from pydantic import AnyUrl, FileUrl from pydantic import AnyUrl, FileUrl
from mcp.server.fastmcp.resources import ( from mcp.server.fastmcp.resources import (
@@ -80,6 +81,7 @@ class TestResourceManager:
manager.add_resource(resource) manager.add_resource(resource)
assert "Resource already exists" not in caplog.text assert "Resource already exists" not in caplog.text
@pytest.mark.anyio
async def test_get_resource(self, temp_file: Path): async def test_get_resource(self, temp_file: Path):
"""Test getting a resource by URI.""" """Test getting a resource by URI."""
manager = ResourceManager() manager = ResourceManager()
@@ -92,6 +94,7 @@ class TestResourceManager:
retrieved = await manager.get_resource(resource.uri) retrieved = await manager.get_resource(resource.uri)
assert retrieved == resource assert retrieved == resource
@pytest.mark.anyio
async def test_get_resource_from_template(self): async def test_get_resource_from_template(self):
"""Test getting a resource through a template.""" """Test getting a resource through a template."""
manager = ResourceManager() manager = ResourceManager()
@@ -111,6 +114,7 @@ class TestResourceManager:
content = await resource.read() content = await resource.read()
assert content == "Hello, world!" assert content == "Hello, world!"
@pytest.mark.anyio
async def test_get_unknown_resource(self): async def test_get_unknown_resource(self):
"""Test getting a non-existent resource.""" """Test getting a non-existent resource."""
manager = ResourceManager() manager = ResourceManager()

View File

@@ -1,4 +1,5 @@
import json import json
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
@@ -45,6 +46,7 @@ class TestResourceTemplate:
assert template.matches("test://foo") is None assert template.matches("test://foo") is None
assert template.matches("other://foo/123") is None assert template.matches("other://foo/123") is None
@pytest.mark.anyio
async def test_create_resource(self): async def test_create_resource(self):
"""Test creating a resource from a template.""" """Test creating a resource from a template."""
@@ -68,6 +70,7 @@ class TestResourceTemplate:
data = json.loads(content) data = json.loads(content)
assert data == {"key": "foo", "value": 123} assert data == {"key": "foo", "value": 123}
@pytest.mark.anyio
async def test_template_error(self): async def test_template_error(self):
"""Test error handling in template resource creation.""" """Test error handling in template resource creation."""
@@ -83,6 +86,7 @@ class TestResourceTemplate:
with pytest.raises(ValueError, match="Error creating resource from template"): with pytest.raises(ValueError, match="Error creating resource from template"):
await template.create_resource("fail://test", {"x": "test"}) await template.create_resource("fail://test", {"x": "test"})
@pytest.mark.anyio
async def test_async_text_resource(self): async def test_async_text_resource(self):
"""Test creating a text resource from async function.""" """Test creating a text resource from async function."""
@@ -104,6 +108,7 @@ class TestResourceTemplate:
content = await resource.read() content = await resource.read()
assert content == "Hello, world!" assert content == "Hello, world!"
@pytest.mark.anyio
async def test_async_binary_resource(self): async def test_async_binary_resource(self):
"""Test creating a binary resource from async function.""" """Test creating a binary resource from async function."""
@@ -125,6 +130,7 @@ class TestResourceTemplate:
content = await resource.read() content = await resource.read()
assert content == b"test" assert content == b"test"
@pytest.mark.anyio
async def test_basemodel_conversion(self): async def test_basemodel_conversion(self):
"""Test handling of BaseModel types.""" """Test handling of BaseModel types."""
@@ -152,6 +158,7 @@ class TestResourceTemplate:
data = json.loads(content) data = json.loads(content)
assert data == {"key": "foo", "value": 123} assert data == {"key": "foo", "value": 123}
@pytest.mark.anyio
async def test_custom_type_conversion(self): async def test_custom_type_conversion(self):
"""Test handling of custom types.""" """Test handling of custom types."""

View File

@@ -90,6 +90,7 @@ class TestResourceValidation:
) )
assert resource.mime_type == "application/json" assert resource.mime_type == "application/json"
@pytest.mark.anyio
async def test_resource_read_abstract(self): async def test_resource_read_abstract(self):
"""Test that Resource.read() is abstract.""" """Test that Resource.read() is abstract."""

View File

@@ -1,8 +1,10 @@
import json import json
from mcp.server.fastmcp import FastMCP
import pytest
from pathlib import Path from pathlib import Path
import pytest
from mcp.server.fastmcp import FastMCP
@pytest.fixture() @pytest.fixture()
def test_dir(tmp_path_factory) -> Path: def test_dir(tmp_path_factory) -> Path:
@@ -71,6 +73,7 @@ def tools(mcp: FastMCP, test_dir: Path) -> FastMCP:
return mcp return mcp
@pytest.mark.anyio
async def test_list_resources(mcp: FastMCP): async def test_list_resources(mcp: FastMCP):
resources = await mcp.list_resources() resources = await mcp.list_resources()
assert len(resources) == 4 assert len(resources) == 4
@@ -83,6 +86,7 @@ async def test_list_resources(mcp: FastMCP):
] ]
@pytest.mark.anyio
async def test_read_resource_dir(mcp: FastMCP): async def test_read_resource_dir(mcp: FastMCP):
files = await mcp.read_resource("dir://test_dir") files = await mcp.read_resource("dir://test_dir")
files = json.loads(files) files = json.loads(files)
@@ -94,11 +98,13 @@ async def test_read_resource_dir(mcp: FastMCP):
] ]
@pytest.mark.anyio
async def test_read_resource_file(mcp: FastMCP): async def test_read_resource_file(mcp: FastMCP):
result = await mcp.read_resource("file://test_dir/example.py") result = await mcp.read_resource("file://test_dir/example.py")
assert result == "print('hello world')" assert result == "print('hello world')"
@pytest.mark.anyio
async def test_delete_file(mcp: FastMCP, test_dir: Path): async def test_delete_file(mcp: FastMCP, test_dir: Path):
await mcp.call_tool( await mcp.call_tool(
"delete_file", arguments=dict(path=str(test_dir / "example.py")) "delete_file", arguments=dict(path=str(test_dir / "example.py"))
@@ -106,6 +112,7 @@ async def test_delete_file(mcp: FastMCP, test_dir: Path):
assert not (test_dir / "example.py").exists() assert not (test_dir / "example.py").exists()
@pytest.mark.anyio
async def test_delete_file_and_check_resources(mcp: FastMCP, test_dir: Path): async def test_delete_file_and_check_resources(mcp: FastMCP, test_dir: Path):
await mcp.call_tool( await mcp.call_tool(
"delete_file", arguments=dict(path=str(test_dir / "example.py")) "delete_file", arguments=dict(path=str(test_dir / "example.py"))

View File

@@ -85,6 +85,7 @@ def complex_arguments_fn(
return "ok!" return "ok!"
@pytest.mark.anyio
async def test_complex_function_runtime_arg_validation_non_json(): async def test_complex_function_runtime_arg_validation_non_json():
"""Test that basic non-JSON arguments are validated correctly""" """Test that basic non-JSON arguments are validated correctly"""
meta = func_metadata(complex_arguments_fn) meta = func_metadata(complex_arguments_fn)
@@ -121,6 +122,7 @@ async def test_complex_function_runtime_arg_validation_non_json():
) )
@pytest.mark.anyio
async def test_complex_function_runtime_arg_validation_with_json(): async def test_complex_function_runtime_arg_validation_with_json():
"""Test that JSON string arguments are parsed and validated correctly""" """Test that JSON string arguments are parsed and validated correctly"""
meta = func_metadata(complex_arguments_fn) meta = func_metadata(complex_arguments_fn)
@@ -140,7 +142,7 @@ async def test_complex_function_runtime_arg_validation_with_json():
"unannotated": "test", "unannotated": "test",
"my_model_a": "{}", # JSON string "my_model_a": "{}", # JSON string
"my_model_a_forward_ref": "{}", # JSON string "my_model_a_forward_ref": "{}", # JSON string
"my_model_b": '{"how_many_shrimp": 5, "ok": {"x": 1}, "y": null}', # JSON string "my_model_b": '{"how_many_shrimp": 5, "ok": {"x": 1}, "y": null}',
}, },
arguments_to_pass_directly=None, arguments_to_pass_directly=None,
) )
@@ -197,6 +199,7 @@ def test_skip_names():
assert model.also_keep == 2.5 # type: ignore assert model.also_keep == 2.5 # type: ignore
@pytest.mark.anyio
async def test_lambda_function(): async def test_lambda_function():
"""Test lambda function schema and validation""" """Test lambda function schema and validation"""
fn = lambda x, y=5: x # noqa: E731 fn = lambda x, y=5: x # noqa: E731
@@ -297,7 +300,7 @@ def test_complex_function_json_schema():
}, },
"field_with_default_via_field_annotation_before_nondefault_arg": { "field_with_default_via_field_annotation_before_nondefault_arg": {
"default": 1, "default": 1,
"title": "Field With Default Via Field Annotation Before Nondefault Arg", "title": "Field With Default Via Field Annotation Before Arg",
"type": "integer", "type": "integer",
}, },
"unannotated": {"title": "unannotated", "type": "string"}, "unannotated": {"title": "unannotated", "type": "string"},
@@ -316,11 +319,7 @@ def test_complex_function_json_schema():
"type": "string", "type": "string",
}, },
"my_model_a_with_default": { "my_model_a_with_default": {
"allOf": [ "allOf": [{"$ref": "#/$defs/SomeInputModelA"}],
{
"$ref": "#/$defs/SomeInputModelA"
}
],
"default": {}, "default": {},
}, },
"an_int_with_default": { "an_int_with_default": {

View File

@@ -3,32 +3,34 @@ from pathlib import Path
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
import pytest import pytest
from mcp.shared.exceptions import McpError
from mcp.shared.memory import (
create_connected_server_and_client_session as client_session,
)
from mcp.types import (
ImageContent,
TextContent,
TextResourceContents,
BlobResourceContents,
)
from pydantic import AnyUrl from pydantic import AnyUrl
from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage
from mcp.server.fastmcp.resources import FileResource, FunctionResource from mcp.server.fastmcp.resources import FileResource, FunctionResource
from mcp.server.fastmcp.utilities.types import Image from mcp.server.fastmcp.utilities.types import Image
from mcp.shared.exceptions import McpError
from mcp.shared.memory import (
create_connected_server_and_client_session as client_session,
)
from mcp.types import (
BlobResourceContents,
ImageContent,
TextContent,
TextResourceContents,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
class TestServer: class TestServer:
@pytest.mark.anyio
async def test_create_server(self): async def test_create_server(self):
mcp = FastMCP() mcp = FastMCP()
assert mcp.name == "FastMCP" assert mcp.name == "FastMCP"
@pytest.mark.anyio
async def test_add_tool_decorator(self): async def test_add_tool_decorator(self):
mcp = FastMCP() mcp = FastMCP()
@@ -38,6 +40,7 @@ class TestServer:
assert len(mcp._tool_manager.list_tools()) == 1 assert len(mcp._tool_manager.list_tools()) == 1
@pytest.mark.anyio
async def test_add_tool_decorator_incorrect_usage(self): async def test_add_tool_decorator_incorrect_usage(self):
mcp = FastMCP() mcp = FastMCP()
@@ -47,6 +50,7 @@ class TestServer:
def add(x: int, y: int) -> int: def add(x: int, y: int) -> int:
return x + y return x + y
@pytest.mark.anyio
async def test_add_resource_decorator(self): async def test_add_resource_decorator(self):
mcp = FastMCP() mcp = FastMCP()
@@ -56,6 +60,7 @@ class TestServer:
assert len(mcp._resource_manager._templates) == 1 assert len(mcp._resource_manager._templates) == 1
@pytest.mark.anyio
async def test_add_resource_decorator_incorrect_usage(self): async def test_add_resource_decorator_incorrect_usage(self):
mcp = FastMCP() mcp = FastMCP()
@@ -88,12 +93,14 @@ def mixed_content_tool_fn() -> list[Union[TextContent, ImageContent]]:
class TestServerTools: class TestServerTools:
@pytest.mark.anyio
async def test_add_tool(self): async def test_add_tool(self):
mcp = FastMCP() mcp = FastMCP()
mcp.add_tool(tool_fn) mcp.add_tool(tool_fn)
mcp.add_tool(tool_fn) mcp.add_tool(tool_fn)
assert len(mcp._tool_manager.list_tools()) == 1 assert len(mcp._tool_manager.list_tools()) == 1
@pytest.mark.anyio
async def test_list_tools(self): async def test_list_tools(self):
mcp = FastMCP() mcp = FastMCP()
mcp.add_tool(tool_fn) mcp.add_tool(tool_fn)
@@ -101,6 +108,7 @@ class TestServerTools:
tools = await client.list_tools() tools = await client.list_tools()
assert len(tools.tools) == 1 assert len(tools.tools) == 1
@pytest.mark.anyio
async def test_call_tool(self): async def test_call_tool(self):
mcp = FastMCP() mcp = FastMCP()
mcp.add_tool(tool_fn) mcp.add_tool(tool_fn)
@@ -109,6 +117,7 @@ class TestServerTools:
assert not hasattr(result, "error") assert not hasattr(result, "error")
assert len(result.content) > 0 assert len(result.content) > 0
@pytest.mark.anyio
async def test_tool_exception_handling(self): async def test_tool_exception_handling(self):
mcp = FastMCP() mcp = FastMCP()
mcp.add_tool(error_tool_fn) mcp.add_tool(error_tool_fn)
@@ -120,6 +129,7 @@ class TestServerTools:
assert "Test error" in content.text assert "Test error" in content.text
assert result.isError is True assert result.isError is True
@pytest.mark.anyio
async def test_tool_error_handling(self): async def test_tool_error_handling(self):
mcp = FastMCP() mcp = FastMCP()
mcp.add_tool(error_tool_fn) mcp.add_tool(error_tool_fn)
@@ -131,6 +141,7 @@ class TestServerTools:
assert "Test error" in content.text assert "Test error" in content.text
assert result.isError is True assert result.isError is True
@pytest.mark.anyio
async def test_tool_error_details(self): async def test_tool_error_details(self):
"""Test that exception details are properly formatted in the response""" """Test that exception details are properly formatted in the response"""
mcp = FastMCP() mcp = FastMCP()
@@ -143,6 +154,7 @@ class TestServerTools:
assert "Test error" in content.text assert "Test error" in content.text
assert result.isError is True assert result.isError is True
@pytest.mark.anyio
async def test_tool_return_value_conversion(self): async def test_tool_return_value_conversion(self):
mcp = FastMCP() mcp = FastMCP()
mcp.add_tool(tool_fn) mcp.add_tool(tool_fn)
@@ -153,6 +165,7 @@ class TestServerTools:
assert isinstance(content, TextContent) assert isinstance(content, TextContent)
assert content.text == "3" assert content.text == "3"
@pytest.mark.anyio
async def test_tool_image_helper(self, tmp_path: Path): async def test_tool_image_helper(self, tmp_path: Path):
# Create a test image # Create a test image
image_path = tmp_path / "test.png" image_path = tmp_path / "test.png"
@@ -171,6 +184,7 @@ class TestServerTools:
decoded = base64.b64decode(content.data) decoded = base64.b64decode(content.data)
assert decoded == b"fake png data" assert decoded == b"fake png data"
@pytest.mark.anyio
async def test_tool_mixed_content(self): async def test_tool_mixed_content(self):
mcp = FastMCP() mcp = FastMCP()
mcp.add_tool(mixed_content_tool_fn) mcp.add_tool(mixed_content_tool_fn)
@@ -185,8 +199,10 @@ class TestServerTools:
assert content2.mimeType == "image/png" assert content2.mimeType == "image/png"
assert content2.data == "abc" assert content2.data == "abc"
@pytest.mark.anyio
async def test_tool_mixed_list_with_image(self, tmp_path: Path): async def test_tool_mixed_list_with_image(self, tmp_path: Path):
"""Test that lists containing Image objects and other types are handled correctly""" """Test that lists containing Image objects and other types are handled
correctly"""
# Create a test image # Create a test image
image_path = tmp_path / "test.png" image_path = tmp_path / "test.png"
image_path.write_bytes(b"test image data") image_path.write_bytes(b"test image data")
@@ -224,6 +240,7 @@ class TestServerTools:
class TestServerResources: class TestServerResources:
@pytest.mark.anyio
async def test_text_resource(self): async def test_text_resource(self):
mcp = FastMCP() mcp = FastMCP()
@@ -240,6 +257,7 @@ class TestServerResources:
assert isinstance(result.contents[0], TextResourceContents) assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Hello, world!" assert result.contents[0].text == "Hello, world!"
@pytest.mark.anyio
async def test_binary_resource(self): async def test_binary_resource(self):
mcp = FastMCP() mcp = FastMCP()
@@ -259,6 +277,7 @@ class TestServerResources:
assert isinstance(result.contents[0], BlobResourceContents) assert isinstance(result.contents[0], BlobResourceContents)
assert result.contents[0].blob == base64.b64encode(b"Binary data").decode() assert result.contents[0].blob == base64.b64encode(b"Binary data").decode()
@pytest.mark.anyio
async def test_file_resource_text(self, tmp_path: Path): async def test_file_resource_text(self, tmp_path: Path):
mcp = FastMCP() mcp = FastMCP()
@@ -276,6 +295,7 @@ class TestServerResources:
assert isinstance(result.contents[0], TextResourceContents) assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Hello from file!" assert result.contents[0].text == "Hello from file!"
@pytest.mark.anyio
async def test_file_resource_binary(self, tmp_path: Path): async def test_file_resource_binary(self, tmp_path: Path):
mcp = FastMCP() mcp = FastMCP()
@@ -301,6 +321,7 @@ class TestServerResources:
class TestServerResourceTemplates: class TestServerResourceTemplates:
@pytest.mark.anyio
async def test_resource_with_params(self): async def test_resource_with_params(self):
"""Test that a resource with function parameters raises an error if the URI """Test that a resource with function parameters raises an error if the URI
parameters don't match""" parameters don't match"""
@@ -312,6 +333,7 @@ class TestServerResourceTemplates:
def get_data_fn(param: str) -> str: def get_data_fn(param: str) -> str:
return f"Data: {param}" return f"Data: {param}"
@pytest.mark.anyio
async def test_resource_with_uri_params(self): async def test_resource_with_uri_params(self):
"""Test that a resource with URI parameters is automatically a template""" """Test that a resource with URI parameters is automatically a template"""
mcp = FastMCP() mcp = FastMCP()
@@ -322,6 +344,7 @@ class TestServerResourceTemplates:
def get_data() -> str: def get_data() -> str:
return "Data" return "Data"
@pytest.mark.anyio
async def test_resource_with_untyped_params(self): async def test_resource_with_untyped_params(self):
"""Test that a resource with untyped parameters raises an error""" """Test that a resource with untyped parameters raises an error"""
mcp = FastMCP() mcp = FastMCP()
@@ -330,6 +353,7 @@ class TestServerResourceTemplates:
def get_data(param) -> str: def get_data(param) -> str:
return "Data" return "Data"
@pytest.mark.anyio
async def test_resource_matching_params(self): async def test_resource_matching_params(self):
"""Test that a resource with matching URI and function parameters works""" """Test that a resource with matching URI and function parameters works"""
mcp = FastMCP() mcp = FastMCP()
@@ -343,6 +367,7 @@ class TestServerResourceTemplates:
assert isinstance(result.contents[0], TextResourceContents) assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Data for test" assert result.contents[0].text == "Data for test"
@pytest.mark.anyio
async def test_resource_mismatched_params(self): async def test_resource_mismatched_params(self):
"""Test that mismatched parameters raise an error""" """Test that mismatched parameters raise an error"""
mcp = FastMCP() mcp = FastMCP()
@@ -353,6 +378,7 @@ class TestServerResourceTemplates:
def get_data(user: str) -> str: def get_data(user: str) -> str:
return f"Data for {user}" return f"Data for {user}"
@pytest.mark.anyio
async def test_resource_multiple_params(self): async def test_resource_multiple_params(self):
"""Test that multiple parameters work correctly""" """Test that multiple parameters work correctly"""
mcp = FastMCP() mcp = FastMCP()
@@ -368,6 +394,7 @@ class TestServerResourceTemplates:
assert isinstance(result.contents[0], TextResourceContents) assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Data for cursor/fastmcp" assert result.contents[0].text == "Data for cursor/fastmcp"
@pytest.mark.anyio
async def test_resource_multiple_mismatched_params(self): async def test_resource_multiple_mismatched_params(self):
"""Test that mismatched parameters raise an error""" """Test that mismatched parameters raise an error"""
mcp = FastMCP() mcp = FastMCP()
@@ -390,6 +417,7 @@ class TestServerResourceTemplates:
assert isinstance(result.contents[0], TextResourceContents) assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Static data" assert result.contents[0].text == "Static data"
@pytest.mark.anyio
async def test_template_to_resource_conversion(self): async def test_template_to_resource_conversion(self):
"""Test that templates are properly converted to resources when accessed""" """Test that templates are properly converted to resources when accessed"""
mcp = FastMCP() mcp = FastMCP()
@@ -412,6 +440,7 @@ class TestServerResourceTemplates:
class TestContextInjection: class TestContextInjection:
"""Test context injection in tools.""" """Test context injection in tools."""
@pytest.mark.anyio
async def test_context_detection(self): async def test_context_detection(self):
"""Test that context parameters are properly detected.""" """Test that context parameters are properly detected."""
mcp = FastMCP() mcp = FastMCP()
@@ -422,6 +451,7 @@ class TestContextInjection:
tool = mcp._tool_manager.add_tool(tool_with_context) tool = mcp._tool_manager.add_tool(tool_with_context)
assert tool.context_kwarg == "ctx" assert tool.context_kwarg == "ctx"
@pytest.mark.anyio
async def test_context_injection(self): async def test_context_injection(self):
"""Test that context is properly injected into tool calls.""" """Test that context is properly injected into tool calls."""
mcp = FastMCP() mcp = FastMCP()
@@ -439,6 +469,7 @@ class TestContextInjection:
assert "Request" in content.text assert "Request" in content.text
assert "42" in content.text assert "42" in content.text
@pytest.mark.anyio
async def test_async_context(self): async def test_async_context(self):
"""Test that context works in async functions.""" """Test that context works in async functions."""
mcp = FastMCP() mcp = FastMCP()
@@ -456,6 +487,7 @@ class TestContextInjection:
assert "Async request" in content.text assert "Async request" in content.text
assert "42" in content.text assert "42" in content.text
@pytest.mark.anyio
async def test_context_logging(self): async def test_context_logging(self):
"""Test that context logging methods work.""" """Test that context logging methods work."""
mcp = FastMCP() mcp = FastMCP()
@@ -475,6 +507,7 @@ class TestContextInjection:
assert isinstance(content, TextContent) assert isinstance(content, TextContent)
assert "Logged messages for test" in content.text assert "Logged messages for test" in content.text
@pytest.mark.anyio
async def test_optional_context(self): async def test_optional_context(self):
"""Test that context is optional.""" """Test that context is optional."""
mcp = FastMCP() mcp = FastMCP()
@@ -490,6 +523,7 @@ class TestContextInjection:
assert isinstance(content, TextContent) assert isinstance(content, TextContent)
assert content.text == "42" assert content.text == "42"
@pytest.mark.anyio
async def test_context_resource_access(self): async def test_context_resource_access(self):
"""Test that context can access resources.""" """Test that context can access resources."""
mcp = FastMCP() mcp = FastMCP()
@@ -514,6 +548,7 @@ class TestContextInjection:
class TestServerPrompts: class TestServerPrompts:
"""Test prompt functionality in FastMCP server.""" """Test prompt functionality in FastMCP server."""
@pytest.mark.anyio
async def test_prompt_decorator(self): async def test_prompt_decorator(self):
"""Test that the prompt decorator registers prompts correctly.""" """Test that the prompt decorator registers prompts correctly."""
mcp = FastMCP() mcp = FastMCP()
@@ -530,6 +565,7 @@ class TestServerPrompts:
assert isinstance(content[0].content, TextContent) assert isinstance(content[0].content, TextContent)
assert content[0].content.text == "Hello, world!" assert content[0].content.text == "Hello, world!"
@pytest.mark.anyio
async def test_prompt_decorator_with_name(self): async def test_prompt_decorator_with_name(self):
"""Test prompt decorator with custom name.""" """Test prompt decorator with custom name."""
mcp = FastMCP() mcp = FastMCP()
@@ -545,6 +581,7 @@ class TestServerPrompts:
assert isinstance(content[0].content, TextContent) assert isinstance(content[0].content, TextContent)
assert content[0].content.text == "Hello, world!" assert content[0].content.text == "Hello, world!"
@pytest.mark.anyio
async def test_prompt_decorator_with_description(self): async def test_prompt_decorator_with_description(self):
"""Test prompt decorator with custom description.""" """Test prompt decorator with custom description."""
mcp = FastMCP() mcp = FastMCP()
@@ -569,6 +606,7 @@ class TestServerPrompts:
def fn() -> str: def fn() -> str:
return "Hello, world!" return "Hello, world!"
@pytest.mark.anyio
async def test_list_prompts(self): async def test_list_prompts(self):
"""Test listing prompts through MCP protocol.""" """Test listing prompts through MCP protocol."""
mcp = FastMCP() mcp = FastMCP()
@@ -590,6 +628,7 @@ class TestServerPrompts:
assert prompt.arguments[1].name == "optional" assert prompt.arguments[1].name == "optional"
assert prompt.arguments[1].required is False assert prompt.arguments[1].required is False
@pytest.mark.anyio
async def test_get_prompt(self): async def test_get_prompt(self):
"""Test getting a prompt through MCP protocol.""" """Test getting a prompt through MCP protocol."""
mcp = FastMCP() mcp = FastMCP()
@@ -607,6 +646,7 @@ class TestServerPrompts:
assert isinstance(content, TextContent) assert isinstance(content, TextContent)
assert content.text == "Hello, World!" assert content.text == "Hello, World!"
@pytest.mark.anyio
async def test_get_prompt_with_resource(self): async def test_get_prompt_with_resource(self):
"""Test getting a prompt that returns resource content.""" """Test getting a prompt that returns resource content."""
mcp = FastMCP() mcp = FastMCP()
@@ -636,6 +676,7 @@ class TestServerPrompts:
assert resource.text == "File contents" assert resource.text == "File contents"
assert resource.mimeType == "text/plain" assert resource.mimeType == "text/plain"
@pytest.mark.anyio
async def test_get_unknown_prompt(self): async def test_get_unknown_prompt(self):
"""Test error when getting unknown prompt.""" """Test error when getting unknown prompt."""
mcp = FastMCP() mcp = FastMCP()
@@ -643,6 +684,7 @@ class TestServerPrompts:
with pytest.raises(McpError, match="Unknown prompt"): with pytest.raises(McpError, match="Unknown prompt"):
await client.get_prompt("unknown") await client.get_prompt("unknown")
@pytest.mark.anyio
async def test_get_prompt_missing_args(self): async def test_get_prompt_missing_args(self):
"""Test error when required arguments are missing.""" """Test error when required arguments are missing."""
mcp = FastMCP() mcp = FastMCP()

View File

@@ -1,9 +1,10 @@
import json
import logging import logging
from typing import Optional from typing import Optional
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
import json
from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import ToolManager from mcp.server.fastmcp.tools import ToolManager
@@ -27,6 +28,7 @@ class TestAddTools:
assert tool.parameters["properties"]["a"]["type"] == "integer" assert tool.parameters["properties"]["a"]["type"] == "integer"
assert tool.parameters["properties"]["b"]["type"] == "integer" assert tool.parameters["properties"]["b"]["type"] == "integer"
@pytest.mark.anyio
async def test_async_function(self): async def test_async_function(self):
"""Test registering and running an async function.""" """Test registering and running an async function."""
@@ -111,6 +113,7 @@ class TestAddTools:
class TestCallTools: class TestCallTools:
@pytest.mark.anyio
async def test_call_tool(self): async def test_call_tool(self):
def add(a: int, b: int) -> int: def add(a: int, b: int) -> int:
"""Add two numbers.""" """Add two numbers."""
@@ -121,6 +124,7 @@ class TestCallTools:
result = await manager.call_tool("add", {"a": 1, "b": 2}) result = await manager.call_tool("add", {"a": 1, "b": 2})
assert result == 3 assert result == 3
@pytest.mark.anyio
async def test_call_async_tool(self): async def test_call_async_tool(self):
async def double(n: int) -> int: async def double(n: int) -> int:
"""Double a number.""" """Double a number."""
@@ -131,6 +135,7 @@ class TestCallTools:
result = await manager.call_tool("double", {"n": 5}) result = await manager.call_tool("double", {"n": 5})
assert result == 10 assert result == 10
@pytest.mark.anyio
async def test_call_tool_with_default_args(self): async def test_call_tool_with_default_args(self):
def add(a: int, b: int = 1) -> int: def add(a: int, b: int = 1) -> int:
"""Add two numbers.""" """Add two numbers."""
@@ -141,6 +146,7 @@ class TestCallTools:
result = await manager.call_tool("add", {"a": 1}) result = await manager.call_tool("add", {"a": 1})
assert result == 2 assert result == 2
@pytest.mark.anyio
async def test_call_tool_with_missing_args(self): async def test_call_tool_with_missing_args(self):
def add(a: int, b: int) -> int: def add(a: int, b: int) -> int:
"""Add two numbers.""" """Add two numbers."""
@@ -151,11 +157,13 @@ class TestCallTools:
with pytest.raises(ToolError): with pytest.raises(ToolError):
await manager.call_tool("add", {"a": 1}) await manager.call_tool("add", {"a": 1})
@pytest.mark.anyio
async def test_call_unknown_tool(self): async def test_call_unknown_tool(self):
manager = ToolManager() manager = ToolManager()
with pytest.raises(ToolError): with pytest.raises(ToolError):
await manager.call_tool("unknown", {"a": 1}) await manager.call_tool("unknown", {"a": 1})
@pytest.mark.anyio
async def test_call_tool_with_list_int_input(self): async def test_call_tool_with_list_int_input(self):
def sum_vals(vals: list[int]) -> int: def sum_vals(vals: list[int]) -> int:
return sum(vals) return sum(vals)
@@ -168,6 +176,7 @@ class TestCallTools:
result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}) result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]})
assert result == 6 assert result == 6
@pytest.mark.anyio
async def test_call_tool_with_list_str_or_str_input(self): async def test_call_tool_with_list_str_or_str_input(self):
def concat_strs(vals: list[str] | str) -> str: def concat_strs(vals: list[str] | str) -> str:
return vals if isinstance(vals, str) else "".join(vals) return vals if isinstance(vals, str) else "".join(vals)
@@ -184,6 +193,7 @@ class TestCallTools:
result = await manager.call_tool("concat_strs", {"vals": '"a"'}) result = await manager.call_tool("concat_strs", {"vals": '"a"'})
assert result == '"a"' assert result == '"a"'
@pytest.mark.anyio
async def test_call_tool_with_complex_model(self): async def test_call_tool_with_complex_model(self):
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
@@ -212,6 +222,7 @@ class TestCallTools:
class TestToolSchema: class TestToolSchema:
@pytest.mark.anyio
async def test_context_arg_excluded_from_schema(self): async def test_context_arg_excluded_from_schema(self):
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
@@ -229,7 +240,8 @@ class TestContextHandling:
"""Test context handling in the tool manager.""" """Test context handling in the tool manager."""
def test_context_parameter_detection(self): def test_context_parameter_detection(self):
"""Test that context parameters are properly detected in Tool.from_function().""" """Test that context parameters are properly detected in
Tool.from_function()."""
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
def tool_with_context(x: int, ctx: Context) -> str: def tool_with_context(x: int, ctx: Context) -> str:
@@ -245,6 +257,7 @@ class TestContextHandling:
tool = manager.add_tool(tool_without_context) tool = manager.add_tool(tool_without_context)
assert tool.context_kwarg is None assert tool.context_kwarg is None
@pytest.mark.anyio
async def test_context_injection(self): async def test_context_injection(self):
"""Test that context is properly injected during tool execution.""" """Test that context is properly injected during tool execution."""
from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp import Context, FastMCP
@@ -261,6 +274,7 @@ class TestContextHandling:
result = await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) result = await manager.call_tool("tool_with_context", {"x": 42}, context=ctx)
assert result == "42" assert result == "42"
@pytest.mark.anyio
async def test_context_injection_async(self): async def test_context_injection_async(self):
"""Test that context is properly injected in async tools.""" """Test that context is properly injected in async tools."""
from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp import Context, FastMCP
@@ -277,6 +291,7 @@ class TestContextHandling:
result = await manager.call_tool("async_tool", {"x": 42}, context=ctx) result = await manager.call_tool("async_tool", {"x": 42}, context=ctx)
assert result == "42" assert result == "42"
@pytest.mark.anyio
async def test_context_optional(self): async def test_context_optional(self):
"""Test that context is optional when calling tools.""" """Test that context is optional when calling tools."""
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
@@ -290,6 +305,7 @@ class TestContextHandling:
result = await manager.call_tool("tool_with_context", {"x": 42}) result = await manager.call_tool("tool_with_context", {"x": 42})
assert result == "42" assert result == "42"
@pytest.mark.anyio
async def test_context_error_handling(self): async def test_context_error_handling(self):
"""Test error handling when context injection fails.""" """Test error handling when context injection fails."""
from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp import Context, FastMCP

View File

@@ -1,50 +1,54 @@
"""Tests for example servers""" """Tests for example servers"""
import pytest import pytest
from mcp.shared.memory import create_connected_server_and_client_session as client_session
from mcp.shared.memory import (
create_connected_server_and_client_session as client_session,
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_simple_echo(): async def test_simple_echo():
"""Test the simple echo server""" """Test the simple echo server"""
from examples.fastmcp.simple_echo import mcp from examples.fastmcp.simple_echo import mcp
async with client_session(mcp._mcp_server) as client: async with client_session(mcp._mcp_server) as client:
result = await client.call_tool("echo", {"text": "hello"}) result = await client.call_tool("echo", {"text": "hello"})
assert len(result.content) == 1 assert len(result.content) == 1
content = result.content[0] content = result.content[0]
assert content.text == "hello" assert content.text == "hello"
@pytest.mark.anyio @pytest.mark.anyio
async def test_complex_inputs(): async def test_complex_inputs():
"""Test the complex inputs server""" """Test the complex inputs server"""
from examples.fastmcp.complex_inputs import mcp from examples.fastmcp.complex_inputs import mcp
async with client_session(mcp._mcp_server) as client: async with client_session(mcp._mcp_server) as client:
tank = { tank = {"shrimp": [{"name": "bob"}, {"name": "alice"}]}
"shrimp": [{"name": "bob"}, {"name": "alice"}] result = await client.call_tool(
} "name_shrimp", {"tank": tank, "extra_names": ["charlie"]}
result = await client.call_tool("name_shrimp", { )
"tank": tank,
"extra_names": ["charlie"]
})
assert len(result.content) == 3 assert len(result.content) == 3
assert result.content[0].text == "bob" assert result.content[0].text == "bob"
assert result.content[1].text == "alice" assert result.content[1].text == "alice"
assert result.content[2].text == "charlie" assert result.content[2].text == "charlie"
@pytest.mark.anyio @pytest.mark.anyio
async def test_desktop(): async def test_desktop():
"""Test the desktop server""" """Test the desktop server"""
from examples.fastmcp.desktop import mcp from examples.fastmcp.desktop import mcp
async with client_session(mcp._mcp_server) as client: async with client_session(mcp._mcp_server) as client:
# Test the add function # Test the add function
result = await client.call_tool("add", {"a": 1, "b": 2}) result = await client.call_tool("add", {"a": 1, "b": 2})
assert len(result.content) == 1 assert len(result.content) == 1
content = result.content[0] content = result.content[0]
assert content.text == "3" assert content.text == "3"
# Test the desktop resource # Test the desktop resource
result = await client.read_resource("dir://desktop") result = await client.read_resource("dir://desktop")
assert len(result.contents) == 1 assert len(result.contents) == 1
content = result.contents[0] content = result.contents[0]
assert isinstance(content.text, str) assert isinstance(content.text, str)

View File

@@ -1,3 +1,5 @@
import pytest
from mcp.types import ( from mcp.types import (
LATEST_PROTOCOL_VERSION, LATEST_PROTOCOL_VERSION,
ClientRequest, ClientRequest,
@@ -6,7 +8,8 @@ from mcp.types import (
) )
def test_jsonrpc_request(): @pytest.mark.anyio
async def test_jsonrpc_request():
json_data = { json_data = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": 1, "id": 1,

14
uv.lock generated
View File

@@ -216,7 +216,6 @@ rich = [
dev = [ dev = [
{ name = "pyright" }, { name = "pyright" },
{ name = "pytest" }, { name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-flakefinder" }, { name = "pytest-flakefinder" },
{ name = "pytest-xdist" }, { name = "pytest-xdist" },
{ name = "ruff" }, { name = "ruff" },
@@ -241,7 +240,6 @@ requires-dist = [
dev = [ dev = [
{ name = "pyright", specifier = ">=1.1.378" }, { name = "pyright", specifier = ">=1.1.378" },
{ name = "pytest", specifier = ">=8.3.3" }, { name = "pytest", specifier = ">=8.3.3" },
{ name = "pytest-asyncio", specifier = ">=0.24.0" },
{ name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-flakefinder", specifier = ">=1.1.0" },
{ name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "pytest-xdist", specifier = ">=3.6.1" },
{ name = "ruff", specifier = ">=0.6.9" }, { name = "ruff", specifier = ">=0.6.9" },
@@ -526,18 +524,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 },
] ]
[[package]]
name = "pytest-asyncio"
version = "0.24.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/52/6d/c6cf50ce320cf8611df7a1254d86233b3df7cc07f9b5f5cbcb82e08aa534/pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276", size = 49855 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/96/31/6607dab48616902f76885dfcf62c08d929796fc3b2d2318faf9fd54dbed9/pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b", size = 18024 },
]
[[package]] [[package]]
name = "pytest-flakefinder" name = "pytest-flakefinder"
version = "1.1.0" version = "1.1.0"