Add strict mode to pyright (#315)

* Add strict mode to pyright

* Apply UP rule

* fix readme

* More correct

* Leave wrong Context for now

* Add strict mode to pyright

* Apply UP rule

* fix readme

* fix

* ignore
This commit is contained in:
Marcelo Trylesinski
2025-03-20 09:13:08 +00:00
committed by GitHub
parent 5a54d82459
commit ae77772ea8
27 changed files with 194 additions and 133 deletions

View File

@@ -16,33 +16,36 @@
<!-- omit in toc --> <!-- omit in toc -->
## Table of Contents ## Table of Contents
- [Overview](#overview) - [MCP Python SDK](#mcp-python-sdk)
- [Installation](#installation) - [Overview](#overview)
- [Quickstart](#quickstart) - [Installation](#installation)
- [What is MCP?](#what-is-mcp) - [Adding MCP to your python project](#adding-mcp-to-your-python-project)
- [Core Concepts](#core-concepts) - [Running the standalone MCP development tools](#running-the-standalone-mcp-development-tools)
- [Server](#server) - [Quickstart](#quickstart)
- [Resources](#resources) - [What is MCP?](#what-is-mcp)
- [Tools](#tools) - [Core Concepts](#core-concepts)
- [Prompts](#prompts) - [Server](#server)
- [Images](#images) - [Resources](#resources)
- [Context](#context) - [Tools](#tools)
- [Running Your Server](#running-your-server) - [Prompts](#prompts)
- [Development Mode](#development-mode) - [Images](#images)
- [Claude Desktop Integration](#claude-desktop-integration) - [Context](#context)
- [Direct Execution](#direct-execution) - [Running Your Server](#running-your-server)
- [Mounting to an Existing ASGI Server](#mounting-to-an-existing-asgi-server) - [Development Mode](#development-mode)
- [Examples](#examples) - [Claude Desktop Integration](#claude-desktop-integration)
- [Echo Server](#echo-server) - [Direct Execution](#direct-execution)
- [SQLite Explorer](#sqlite-explorer) - [Mounting to an Existing ASGI Server](#mounting-to-an-existing-asgi-server)
- [Advanced Usage](#advanced-usage) - [Examples](#examples)
- [Low-Level Server](#low-level-server) - [Echo Server](#echo-server)
- [Writing MCP Clients](#writing-mcp-clients) - [SQLite Explorer](#sqlite-explorer)
- [MCP Primitives](#mcp-primitives) - [Advanced Usage](#advanced-usage)
- [Server Capabilities](#server-capabilities) - [Low-Level Server](#low-level-server)
- [Documentation](#documentation) - [Writing MCP Clients](#writing-mcp-clients)
- [Contributing](#contributing) - [MCP Primitives](#mcp-primitives)
- [License](#license) - [Server Capabilities](#server-capabilities)
- [Documentation](#documentation)
- [Contributing](#contributing)
- [License](#license)
[pypi-badge]: https://img.shields.io/pypi/v/mcp.svg [pypi-badge]: https://img.shields.io/pypi/v/mcp.svg
[pypi-url]: https://pypi.org/project/mcp/ [pypi-url]: https://pypi.org/project/mcp/
@@ -143,8 +146,8 @@ The FastMCP server is your core interface to the MCP protocol. It handles connec
```python ```python
# Add lifespan support for startup/shutdown with strong typing # Add lifespan support for startup/shutdown with strong typing
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import AsyncIterator
from fake_database import Database # Replace with your actual DB type from fake_database import Database # Replace with your actual DB type
@@ -442,7 +445,7 @@ For more control, you can use the low-level server implementation directly. This
```python ```python
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncIterator from collections.abc import AsyncIterator
from fake_database import Database # Replace with your actual DB type from fake_database import Database # Replace with your actual DB type

View File

@@ -76,13 +76,11 @@ packages = ["src/mcp"]
include = ["src/mcp", "tests"] include = ["src/mcp", "tests"]
venvPath = "." venvPath = "."
venv = ".venv" venv = ".venv"
strict = [ strict = ["src/mcp/**/*.py"]
"src/mcp/server/fastmcp/tools/base.py", exclude = ["src/mcp/types.py"]
"src/mcp/client/*.py"
]
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I"] select = ["E", "F", "I", "UP"]
ignore = [] ignore = []
[tool.ruff] [tool.ruff]

View File

@@ -4,6 +4,7 @@ import json
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any
from mcp.server.fastmcp.utilities.logging import get_logger from mcp.server.fastmcp.utilities.logging import get_logger
@@ -116,10 +117,7 @@ def update_claude_config(
# Add fastmcp run command # Add fastmcp run command
args.extend(["mcp", "run", file_spec]) args.extend(["mcp", "run", file_spec])
server_config = { server_config: dict[str, Any] = {"command": "uv", "args": args}
"command": "uv",
"args": args,
}
# Add environment variables if specified # Add environment variables if specified
if env_vars: if env_vars:

View File

@@ -1,7 +1,7 @@
import json import json
import logging import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

View File

@@ -2,8 +2,8 @@
import inspect import inspect
import json import json
from collections.abc import Callable from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Awaitable, Literal, Sequence from typing import Any, Literal
import pydantic_core import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -19,7 +19,7 @@ class Message(BaseModel):
role: Literal["user", "assistant"] role: Literal["user", "assistant"]
content: CONTENT_TYPES content: CONTENT_TYPES
def __init__(self, content: str | CONTENT_TYPES, **kwargs): def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
if isinstance(content, str): if isinstance(content, str):
content = TextContent(type="text", text=content) content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
@@ -30,7 +30,7 @@ class UserMessage(Message):
role: Literal["user", "assistant"] = "user" role: Literal["user", "assistant"] = "user"
def __init__(self, content: str | CONTENT_TYPES, **kwargs): def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
role: Literal["user", "assistant"] = "assistant" role: Literal["user", "assistant"] = "assistant"
def __init__(self, content: str | CONTENT_TYPES, **kwargs): def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
message_validator = TypeAdapter(UserMessage | AssistantMessage) message_validator = TypeAdapter[UserMessage | AssistantMessage](
UserMessage | AssistantMessage
)
SyncPromptResult = ( SyncPromptResult = (
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
arguments: list[PromptArgument] | None = Field( arguments: list[PromptArgument] | None = Field(
None, description="Arguments that can be passed to the prompt" None, description="Arguments that can be passed to the prompt"
) )
fn: Callable = Field(exclude=True) fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
@classmethod @classmethod
def from_function( def from_function(
cls, cls,
fn: Callable[..., PromptResult], fn: Callable[..., PromptResult | Awaitable[PromptResult]],
name: str | None = None, name: str | None = None,
description: str | None = None, description: str | None = None,
) -> "Prompt": ) -> "Prompt":
@@ -99,7 +101,7 @@ class Prompt(BaseModel):
parameters = TypeAdapter(fn).json_schema() parameters = TypeAdapter(fn).json_schema()
# Convert parameters to PromptArguments # Convert parameters to PromptArguments
arguments = [] arguments: list[PromptArgument] = []
if "properties" in parameters: if "properties" in parameters:
for param_name, param in parameters["properties"].items(): for param_name, param in parameters["properties"].items():
required = param_name in parameters.get("required", []) required = param_name in parameters.get("required", [])
@@ -138,25 +140,23 @@ class Prompt(BaseModel):
result = await result result = await result
# Validate messages # Validate messages
if not isinstance(result, (list, tuple)): if not isinstance(result, list | tuple):
result = [result] result = [result]
# Convert result to messages # Convert result to messages
messages = [] messages: list[Message] = []
for msg in result: for msg in result: # type: ignore[reportUnknownVariableType]
try: try:
if isinstance(msg, Message): if isinstance(msg, Message):
messages.append(msg) messages.append(msg)
elif isinstance(msg, dict): elif isinstance(msg, dict):
msg = message_validator.validate_python(msg) messages.append(message_validator.validate_python(msg))
messages.append(msg)
elif isinstance(msg, str): elif isinstance(msg, str):
messages.append( content = TextContent(type="text", text=msg)
UserMessage(content=TextContent(type="text", text=msg)) messages.append(UserMessage(content=content))
)
else: else:
msg = json.dumps(pydantic_core.to_jsonable_python(msg)) content = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=msg)) messages.append(Message(role="user", content=content))
except Exception: except Exception:
raise ValueError( raise ValueError(
f"Could not convert prompt result to message: {msg}" f"Could not convert prompt result to message: {msg}"

View File

@@ -1,6 +1,7 @@
"""Resource manager functionality.""" """Resource manager functionality."""
from typing import Callable from collections.abc import Callable
from typing import Any
from pydantic import AnyUrl from pydantic import AnyUrl
@@ -47,7 +48,7 @@ class ResourceManager:
def add_template( def add_template(
self, self,
fn: Callable, fn: Callable[..., Any],
uri_template: str, uri_template: str,
name: str | None = None, name: str | None = None,
description: str | None = None, description: str | None = None,

View File

@@ -1,8 +1,11 @@
"""Resource template functionality.""" """Resource template functionality."""
from __future__ import annotations
import inspect import inspect
import re import re
from typing import Any, Callable from collections.abc import Callable
from typing import Any
from pydantic import BaseModel, Field, TypeAdapter, validate_call from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -20,18 +23,20 @@ class ResourceTemplate(BaseModel):
mime_type: str = Field( mime_type: str = Field(
default="text/plain", description="MIME type of the resource content" default="text/plain", description="MIME type of the resource content"
) )
fn: Callable = Field(exclude=True) fn: Callable[..., Any] = Field(exclude=True)
parameters: dict = Field(description="JSON schema for function parameters") parameters: dict[str, Any] = Field(
description="JSON schema for function parameters"
)
@classmethod @classmethod
def from_function( def from_function(
cls, cls,
fn: Callable, fn: Callable[..., Any],
uri_template: str, uri_template: str,
name: str | None = None, name: str | None = None,
description: str | None = None, description: str | None = None,
mime_type: str | None = None, mime_type: str | None = None,
) -> "ResourceTemplate": ) -> ResourceTemplate:
"""Create a template from a function.""" """Create a template from a function."""
func_name = name or fn.__name__ func_name = name or fn.__name__
if func_name == "<lambda>": if func_name == "<lambda>":

View File

@@ -5,13 +5,13 @@ from __future__ import annotations as _annotations
import inspect import inspect
import json import json
import re import re
from collections.abc import AsyncIterator, Iterable from collections.abc import AsyncIterator, Callable, Iterable, Sequence
from contextlib import ( from contextlib import (
AbstractAsyncContextManager, AbstractAsyncContextManager,
asynccontextmanager, asynccontextmanager,
) )
from itertools import chain from itertools import chain
from typing import Any, Callable, Generic, Literal, Sequence from typing import Any, Generic, Literal
import anyio import anyio
import pydantic_core import pydantic_core
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.exceptions import ResourceError
@@ -88,13 +89,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
) )
lifespan: ( lifespan: (
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager") ) = Field(None, description="Lifespan context manager")
def lifespan_wrapper( def lifespan_wrapper(
app: FastMCP, app: FastMCP,
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]], lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]: ) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
@asynccontextmanager @asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
@@ -179,7 +180,7 @@ class FastMCP:
for info in tools for info in tools
] ]
def get_context(self) -> "Context[ServerSession, object]": def get_context(self) -> Context[ServerSession, object]:
""" """
Returns a Context object. Note that the context will only be valid Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error. during a request; outside a request, most methods will error.
@@ -478,9 +479,11 @@ class FastMCP:
"""Return an instance of the SSE server app.""" """Return an instance of the SSE server app."""
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
async def handle_sse(request): async def handle_sse(request: Request) -> None:
async with sse.connect_sse( async with sse.connect_sse(
request.scope, request.receive, request._send request.scope,
request.receive,
request._send, # type: ignore[reportPrivateUsage]
) as streams: ) as streams:
await self._mcp_server.run( await self._mcp_server.run(
streams[0], streams[0],
@@ -535,14 +538,14 @@ def _convert_to_content(
if result is None: if result is None:
return [] return []
if isinstance(result, (TextContent, ImageContent, EmbeddedResource)): if isinstance(result, TextContent | ImageContent | EmbeddedResource):
return [result] return [result]
if isinstance(result, Image): if isinstance(result, Image):
return [result.to_image_content()] return [result.to_image_content()]
if isinstance(result, (list, tuple)): if isinstance(result, list | tuple):
return list(chain.from_iterable(_convert_to_content(item) for item in result)) return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
if not isinstance(result, str): if not isinstance(result, str):
try: try:

View File

@@ -1,11 +1,11 @@
from __future__ import annotations as _annotations from __future__ import annotations as _annotations
import inspect import inspect
from typing import TYPE_CHECKING, Any, Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import mcp.server.fastmcp
from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
@@ -38,8 +38,10 @@ class Tool(BaseModel):
name: str | None = None, name: str | None = None,
description: str | None = None, description: str | None = None,
context_kwarg: str | None = None, context_kwarg: str | None = None,
) -> "Tool": ) -> Tool:
"""Create a Tool from a function.""" """Create a Tool from a function."""
from mcp.server.fastmcp import Context
func_name = name or fn.__name__ func_name = name or fn.__name__
if func_name == "<lambda>": if func_name == "<lambda>":
@@ -48,11 +50,10 @@ class Tool(BaseModel):
func_doc = description or fn.__doc__ or "" func_doc = description or fn.__doc__ or ""
is_async = inspect.iscoroutinefunction(fn) is_async = inspect.iscoroutinefunction(fn)
# Find context parameter if it exists
if context_kwarg is None: if context_kwarg is None:
sig = inspect.signature(fn) sig = inspect.signature(fn)
for param_name, param in sig.parameters.items(): for param_name, param in sig.parameters.items():
if param.annotation is mcp.server.fastmcp.Context: if param.annotation is Context:
context_kwarg = param_name context_kwarg = param_name
break break

View File

@@ -32,7 +32,7 @@ class ToolManager:
def add_tool( def add_tool(
self, self,
fn: Callable, fn: Callable[..., Any],
name: str | None = None, name: str | None = None,
description: str | None = None, description: str | None = None,
) -> Tool: ) -> Tool:

View File

@@ -80,7 +80,7 @@ class FuncMetadata(BaseModel):
dicts (JSON objects) as JSON strings, which can be pre-parsed here. dicts (JSON objects) as JSON strings, which can be pre-parsed here.
""" """
new_data = data.copy() # Shallow copy new_data = data.copy() # Shallow copy
for field_name, field_info in self.arg_model.model_fields.items(): for field_name, _field_info in self.arg_model.model_fields.items():
if field_name not in data.keys(): if field_name not in data.keys():
continue continue
if isinstance(data[field_name], str): if isinstance(data[field_name], str):
@@ -177,7 +177,9 @@ def func_metadata(
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
def try_eval_type(value, globalns, localns): def try_eval_type(
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
) -> tuple[Any, bool]:
try: try:
return eval_type_backport(value, globalns, localns), True return eval_type_backport(value, globalns, localns), True
except NameError: except NameError:

View File

@@ -24,7 +24,7 @@ def configure_logging(
Args: Args:
level: the log level to use level: the log level to use
""" """
handlers = [] handlers: list[logging.Handler] = []
try: try:
from rich.console import Console from rich.console import Console
from rich.logging import RichHandler from rich.logging import RichHandler

View File

@@ -69,9 +69,9 @@ from __future__ import annotations as _annotations
import contextvars import contextvars
import logging import logging
import warnings import warnings
from collections.abc import Awaitable, Callable, Iterable from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Generic, TypeVar from typing import Any, Generic, TypeVar
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -155,9 +155,7 @@ class Server(Generic[LifespanResultT]):
try: try:
from importlib.metadata import version from importlib.metadata import version
v = version(package) return version(package)
if v is not None:
return v
except Exception: except Exception:
pass pass
@@ -320,7 +318,6 @@ class Server(Generic[LifespanResultT]):
contents_list = [ contents_list = [
create_content(content_item.content, content_item.mime_type) create_content(content_item.content, content_item.mime_type)
for content_item in contents for content_item in contents
if isinstance(content_item, ReadResourceContents)
] ]
return types.ServerResult( return types.ServerResult(
types.ReadResourceResult( types.ReadResourceResult(
@@ -511,7 +508,8 @@ class Server(Generic[LifespanResultT]):
raise_exceptions: bool = False, raise_exceptions: bool = False,
): ):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
match message: # TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive]
case ( case (
RequestResponder(request=types.ClientRequest(root=req)) as responder RequestResponder(request=types.ClientRequest(root=req)) as responder
): ):
@@ -527,7 +525,7 @@ class Server(Generic[LifespanResultT]):
async def _handle_request( async def _handle_request(
self, self,
message: RequestResponder, message: RequestResponder[types.ClientRequest, types.ServerResult],
req: Any, req: Any,
session: ServerSession, session: ServerSession,
lifespan_context: LifespanResultT, lifespan_context: LifespanResultT,

View File

@@ -2,9 +2,10 @@
In-memory transports In-memory transports
""" """
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta from datetime import timedelta
from typing import AsyncGenerator from typing import Any
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -52,7 +53,7 @@ async def create_client_server_memory_streams() -> (
@asynccontextmanager @asynccontextmanager
async def create_connected_server_and_client_session( async def create_connected_server_and_client_session(
server: Server, server: Server[Any],
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None, sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None, list_roots_callback: ListRootsFnT | None = None,

View File

@@ -1,10 +1,19 @@
from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic
from pydantic import BaseModel from pydantic import BaseModel
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession from mcp.shared.session import (
BaseSession,
ReceiveNotificationT,
ReceiveRequestT,
SendNotificationT,
SendRequestT,
SendResultT,
)
from mcp.types import ProgressToken from mcp.types import ProgressToken
@@ -14,8 +23,22 @@ class Progress(BaseModel):
@dataclass @dataclass
class ProgressContext: class ProgressContext(
session: BaseSession Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
]
):
session: BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
]
progress_token: ProgressToken progress_token: ProgressToken
total: float | None total: float | None
current: float = field(default=0.0, init=False) current: float = field(default=0.0, init=False)
@@ -29,7 +52,27 @@ class ProgressContext:
@contextmanager @contextmanager
def progress(ctx: RequestContext, total: float | None = None): def progress(
ctx: RequestContext[
BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
]
],
total: float | None = None,
) -> Generator[
ProgressContext[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
None,
]:
if ctx.meta is None or ctx.meta.progressToken is None: if ctx.meta is None or ctx.meta.progressToken is None:
raise ValueError("No progress token provided") raise ValueError("No progress token provided")

View File

@@ -1,7 +1,9 @@
import logging import logging
from collections.abc import Callable
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from datetime import timedelta from datetime import timedelta
from typing import Any, Callable, Generic, TypeVar from types import TracebackType
from typing import Any, Generic, TypeVar
import anyio import anyio
import anyio.lowlevel import anyio.lowlevel
@@ -86,7 +88,12 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._cancel_scope.__enter__() self._cancel_scope.__enter__()
return self return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None: def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the context manager, performing cleanup and notifying completion.""" """Exit the context manager, performing cleanup and notifying completion."""
try: try:
if self._completed: if self._completed:
@@ -112,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
if not self.cancelled: if not self.cancelled:
self._completed = True self._completed = True
await self._session._send_response( await self._session._send_response( # type: ignore[reportPrivateUsage]
request_id=self.request_id, response=response request_id=self.request_id, response=response
) )
@@ -126,7 +133,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._cancel_scope.cancel() self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation # Send an error response to indicate cancellation
await self._session._send_response( await self._session._send_response( # type: ignore[reportPrivateUsage]
request_id=self.request_id, request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None), response=ErrorData(code=0, message="Request cancelled", data=None),
) )
@@ -137,7 +144,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@property @property
def cancelled(self) -> bool: def cancelled(self) -> bool:
return self._cancel_scope is not None and self._cancel_scope.cancel_called return self._cancel_scope.cancel_called
class BaseSession( class BaseSession(
@@ -202,7 +209,12 @@ class BaseSession(
self._task_group.start_soon(self._receive_loop) self._task_group.start_soon(self._receive_loop)
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
await self._exit_stack.aclose() await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this # Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks # would be very surprising behavior), so make sure to cancel the tasks
@@ -324,7 +336,7 @@ class BaseSession(
self._in_flight[responder.request_id] = responder self._in_flight[responder.request_id] = responder
await self._received_request(responder) await self._received_request(responder)
if not responder._completed: if not responder._completed: # type: ignore[reportPrivateUsage]
await self._incoming_message_stream_writer.send(responder) await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification): elif isinstance(message.root, JSONRPCNotification):

View File

@@ -1,7 +1,7 @@
from collections.abc import Callable
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
Callable,
Generic, Generic,
Literal, Literal,
TypeAlias, TypeAlias,
@@ -89,6 +89,7 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
"""Base class for JSON-RPC notifications.""" """Base class for JSON-RPC notifications."""
method: MethodT method: MethodT
params: NotificationParamsT
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@@ -1010,7 +1011,9 @@ class CancelledNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class CancelledNotification(Notification): class CancelledNotification(
Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]
):
""" """
This notification can be sent by either side to indicate that it is cancelling a This notification can be sent by either side to indicate that it is cancelling a
previously-issued request. previously-issued request.

View File

@@ -1,5 +1,6 @@
import json import json
import subprocess import subprocess
from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -8,7 +9,7 @@ from mcp.cli.claude import update_claude_config
@pytest.fixture @pytest.fixture
def temp_config_dir(tmp_path): def temp_config_dir(tmp_path: Path):
"""Create a temporary Claude config directory.""" """Create a temporary Claude config directory."""
config_dir = tmp_path / "Claude" config_dir = tmp_path / "Claude"
config_dir.mkdir() config_dir.mkdir()
@@ -16,23 +17,20 @@ def temp_config_dir(tmp_path):
@pytest.fixture @pytest.fixture
def mock_config_path(temp_config_dir): def mock_config_path(temp_config_dir: Path):
"""Mock get_claude_config_path to return our temporary directory.""" """Mock get_claude_config_path to return our temporary directory."""
with patch("mcp.cli.claude.get_claude_config_path", return_value=temp_config_dir): with patch("mcp.cli.claude.get_claude_config_path", return_value=temp_config_dir):
yield temp_config_dir yield temp_config_dir
def test_command_execution(mock_config_path): def test_command_execution(mock_config_path: Path):
"""Test that the generated command can actually be executed.""" """Test that the generated command can actually be executed."""
# Setup # Setup
server_name = "test_server" server_name = "test_server"
file_spec = "test_server.py:app" file_spec = "test_server.py:app"
# Update config # Update config
success = update_claude_config( success = update_claude_config(file_spec=file_spec, server_name=server_name)
file_spec=file_spec,
server_name=server_name,
)
assert success assert success
# Read the generated config # Read the generated config

View File

@@ -7,11 +7,7 @@ from mcp.shared.context import RequestContext
from mcp.shared.memory import ( from mcp.shared.memory import (
create_connected_server_and_client_session as create_session, create_connected_server_and_client_session as create_session,
) )
from mcp.types import ( from mcp.types import ListRootsResult, Root, TextContent
ListRootsResult,
Root,
TextContent,
)
@pytest.mark.anyio @pytest.mark.anyio
@@ -39,7 +35,7 @@ async def test_list_roots_callback():
return callback_return return callback_return
@server.tool("test_list_roots") @server.tool("test_list_roots")
async def test_list_roots(context: Context, message: str): async def test_list_roots(context: Context, message: str): # type: ignore[reportUnknownMemberType]
roots = await context.session.list_roots() roots = await context.session.list_roots()
assert roots == callback_return assert roots == callback_return
return True return True

View File

@@ -1,4 +1,4 @@
from typing import List, Literal from typing import Literal
import anyio import anyio
import pytest import pytest
@@ -14,7 +14,7 @@ from mcp.types import (
class LoggingCollector: class LoggingCollector:
def __init__(self): def __init__(self):
self.log_messages: List[LoggingMessageNotificationParams] = [] self.log_messages: list[LoggingMessageNotificationParams] = []
async def __call__(self, params: LoggingMessageNotificationParams) -> None: async def __call__(self, params: LoggingMessageNotificationParams) -> None:
self.log_messages.append(params) self.log_messages.append(params)

View File

@@ -1,8 +1,8 @@
"""Test to reproduce issue #88: Random error thrown on response.""" """Test to reproduce issue #88: Random error thrown on response."""
from collections.abc import Sequence
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import Sequence
import anyio import anyio
import pytest import pytest

View File

@@ -1,6 +1,6 @@
import base64 import base64
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING
import pytest import pytest
from pydantic import AnyUrl from pydantic import AnyUrl
@@ -114,7 +114,7 @@ def image_tool_fn(path: str) -> Image:
return Image(path) return Image(path)
def mixed_content_tool_fn() -> list[Union[TextContent, ImageContent]]: def mixed_content_tool_fn() -> list[TextContent | ImageContent]:
return [ return [
TextContent(type="text", text="Hello"), TextContent(type="text", text="Hello"),
ImageContent(type="image", data="abc", mimeType="image/png"), ImageContent(type="image", data="abc", mimeType="image/png"),

View File

@@ -1,6 +1,5 @@
import json import json
import logging import logging
from typing import Optional
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
@@ -296,7 +295,7 @@ class TestContextHandling:
"""Test that context is optional when calling tools.""" """Test that context is optional when calling tools."""
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
def tool_with_context(x: int, ctx: Optional[Context] = None) -> str: def tool_with_context(x: int, ctx: Context | None = None) -> str:
return str(x) return str(x)
manager = ToolManager() manager = ToolManager()

View File

@@ -1,7 +1,7 @@
"""Tests for lifespan functionality in both low-level and FastMCP servers.""" """Tests for lifespan functionality in both low-level and FastMCP servers."""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncIterator
import anyio import anyio
import pytest import pytest

View File

@@ -1,4 +1,4 @@
from typing import AsyncGenerator from collections.abc import AsyncGenerator
import anyio import anyio
import pytest import pytest

View File

@@ -1,7 +1,7 @@
import multiprocessing import multiprocessing
import socket import socket
import time import time
from typing import AsyncGenerator, Generator from collections.abc import AsyncGenerator, Generator
import anyio import anyio
import httpx import httpx
@@ -139,7 +139,7 @@ def server(server_port: int) -> Generator[None, None, None]:
attempt += 1 attempt += 1
else: else:
raise RuntimeError( raise RuntimeError(
"Server failed to start after {} attempts".format(max_attempts) f"Server failed to start after {max_attempts} attempts"
) )
yield yield

View File

@@ -1,7 +1,7 @@
import multiprocessing import multiprocessing
import socket import socket
import time import time
from typing import AsyncGenerator, Generator from collections.abc import AsyncGenerator, Generator
import anyio import anyio
import pytest import pytest
@@ -135,7 +135,7 @@ def server(server_port: int) -> Generator[None, None, None]:
attempt += 1 attempt += 1
else: else:
raise RuntimeError( raise RuntimeError(
"Server failed to start after {} attempts".format(max_attempts) f"Server failed to start after {max_attempts} attempts"
) )
yield yield