Add strict mode to pyright (#315)

* Add strict mode to pyright

* Apply UP rule

* fix readme

* More correct

* Leave wrong Context for now

* Add strict mode to pyright

* Apply UP rule

* fix readme

* fix

* ignore
Marcelo Trylesinski
2025-03-20 09:13:08 +00:00
committed by GitHub
parent 5a54d82459
commit ae77772ea8
27 changed files with 194 additions and 133 deletions

View File

@@ -4,6 +4,7 @@ import json
import os
import sys
from pathlib import Path
from typing import Any
from mcp.server.fastmcp.utilities.logging import get_logger
@@ -116,10 +117,7 @@ def update_claude_config(
# Add fastmcp run command
args.extend(["mcp", "run", file_spec])
server_config = {
"command": "uv",
"args": args,
}
server_config: dict[str, Any] = {"command": "uv", "args": args}
# Add environment variables if specified
if env_vars:
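
Editor's note: without an annotation, pyright infers `dict[str, str | list[str]]` for the literal, which would presumably reject the nested env mapping added in the branch below under strict mode; `dict[str, Any]` keeps the value type open. A minimal sketch of the pattern (the file spec and env values are illustrative, not from the repo):

from typing import Any

args = ["mcp", "run", "server.py"]  # illustrative file spec

# Inferred as dict[str, str | list[str]] without the annotation, which would
# make the nested mapping assignment below an error under pyright --strict.
server_config: dict[str, Any] = {"command": "uv", "args": args}

env_vars = {"API_KEY": "abc123"}
if env_vars:
    server_config["env"] = env_vars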

View File

@@ -1,7 +1,7 @@
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

View File

@@ -2,8 +2,8 @@
import inspect
import json
from collections.abc import Callable
from typing import Any, Awaitable, Literal, Sequence
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -19,7 +19,7 @@ class Message(BaseModel):
role: Literal["user", "assistant"]
content: CONTENT_TYPES
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
if isinstance(content, str):
content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs)
@@ -30,7 +30,7 @@ class UserMessage(Message):
role: Literal["user", "assistant"] = "user"
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
role: Literal["user", "assistant"] = "assistant"
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
message_validator = TypeAdapter(UserMessage | AssistantMessage)
message_validator = TypeAdapter[UserMessage | AssistantMessage](
UserMessage | AssistantMessage
)
SyncPromptResult = (
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
arguments: list[PromptArgument] | None = Field(
None, description="Arguments that can be passed to the prompt"
)
fn: Callable = Field(exclude=True)
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
@classmethod
def from_function(
cls,
fn: Callable[..., PromptResult],
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
name: str | None = None,
description: str | None = None,
) -> "Prompt":
@@ -99,7 +101,7 @@ class Prompt(BaseModel):
parameters = TypeAdapter(fn).json_schema()
# Convert parameters to PromptArguments
arguments = []
arguments: list[PromptArgument] = []
if "properties" in parameters:
for param_name, param in parameters["properties"].items():
required = param_name in parameters.get("required", [])
@@ -138,25 +140,23 @@ class Prompt(BaseModel):
result = await result
# Validate messages
if not isinstance(result, (list, tuple)):
if not isinstance(result, list | tuple):
result = [result]
# Convert result to messages
messages = []
for msg in result:
messages: list[Message] = []
for msg in result: # type: ignore[reportUnknownVariableType]
try:
if isinstance(msg, Message):
messages.append(msg)
elif isinstance(msg, dict):
msg = message_validator.validate_python(msg)
messages.append(msg)
messages.append(message_validator.validate_python(msg))
elif isinstance(msg, str):
messages.append(
UserMessage(content=TextContent(type="text", text=msg))
)
content = TextContent(type="text", text=msg)
messages.append(UserMessage(content=content))
else:
msg = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=msg))
content = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=content))
except Exception:
raise ValueError(
f"Could not convert prompt result to message: {msg}"

View File

@@ -1,6 +1,7 @@
"""Resource manager functionality."""
from typing import Callable
from collections.abc import Callable
from typing import Any
from pydantic import AnyUrl
@@ -47,7 +48,7 @@ class ResourceManager:
def add_template(
self,
fn: Callable,
fn: Callable[..., Any],
uri_template: str,
name: str | None = None,
description: str | None = None,

View File

@@ -1,8 +1,11 @@
"""Resource template functionality."""
from __future__ import annotations
import inspect
import re
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -20,18 +23,20 @@ class ResourceTemplate(BaseModel):
mime_type: str = Field(
default="text/plain", description="MIME type of the resource content"
)
fn: Callable = Field(exclude=True)
parameters: dict = Field(description="JSON schema for function parameters")
fn: Callable[..., Any] = Field(exclude=True)
parameters: dict[str, Any] = Field(
description="JSON schema for function parameters"
)
@classmethod
def from_function(
cls,
fn: Callable,
fn: Callable[..., Any],
uri_template: str,
name: str | None = None,
description: str | None = None,
mime_type: str | None = None,
) -> "ResourceTemplate":
) -> ResourceTemplate:
"""Create a template from a function."""
func_name = name or fn.__name__
if func_name == "<lambda>":

View File

@@ -5,13 +5,13 @@ from __future__ import annotations as _annotations
import inspect
import json
import re
from collections.abc import AsyncIterator, Iterable
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from itertools import chain
from typing import Any, Callable, Generic, Literal, Sequence
from typing import Any, Generic, Literal
import anyio
import pydantic_core
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route
from mcp.server.fastmcp.exceptions import ResourceError
@@ -88,13 +89,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
)
lifespan: (
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager")
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
@@ -179,7 +180,7 @@ class FastMCP:
for info in tools
]
def get_context(self) -> "Context[ServerSession, object]":
def get_context(self) -> Context[ServerSession, object]:
"""
Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error.
@@ -478,9 +479,11 @@ class FastMCP:
"""Return an instance of the SSE server app."""
sse = SseServerTransport("/messages/")
async def handle_sse(request):
async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
request.scope, request.receive, request._send
request.scope,
request.receive,
request._send, # type: ignore[reportPrivateUsage]
) as streams:
await self._mcp_server.run(
streams[0],
@@ -535,14 +538,14 @@ def _convert_to_content(
if result is None:
return []
if isinstance(result, (TextContent, ImageContent, EmbeddedResource)):
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
return [result]
if isinstance(result, Image):
return [result.to_image_content()]
if isinstance(result, (list, tuple)):
return list(chain.from_iterable(_convert_to_content(item) for item in result))
if isinstance(result, list | tuple):
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
if not isinstance(result, str):
try:
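
Editor's note: the `isinstance(x, (A, B))` to `isinstance(x, A | B)` rewrites appear to come from Ruff's pyupgrade (`UP`) rules mentioned in the commit message, likely UP038; on Python 3.10+ `isinstance` accepts a `X | Y` union object directly. A quick illustration:

value: object = (1, 2, 3)

# Equivalent to isinstance(value, (list, tuple)), but written with a PEP 604
# union object, which isinstance accepts on Python 3.10+.
if isinstance(value, list | tuple):
    print(len(value))

if isinstance(value, int | float):
    print("numeric")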

View File

@@ -1,11 +1,11 @@
from __future__ import annotations as _annotations
import inspect
from typing import TYPE_CHECKING, Any, Callable
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
import mcp.server.fastmcp
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
@@ -38,8 +38,10 @@ class Tool(BaseModel):
name: str | None = None,
description: str | None = None,
context_kwarg: str | None = None,
) -> "Tool":
) -> Tool:
"""Create a Tool from a function."""
from mcp.server.fastmcp import Context
func_name = name or fn.__name__
if func_name == "<lambda>":
@@ -48,11 +50,10 @@ class Tool(BaseModel):
func_doc = description or fn.__doc__ or ""
is_async = inspect.iscoroutinefunction(fn)
# Find context parameter if it exists
if context_kwarg is None:
sig = inspect.signature(fn)
for param_name, param in sig.parameters.items():
if param.annotation is mcp.server.fastmcp.Context:
if param.annotation is Context:
context_kwarg = param_name
break

View File

@@ -32,7 +32,7 @@ class ToolManager:
def add_tool(
self,
fn: Callable,
fn: Callable[..., Any],
name: str | None = None,
description: str | None = None,
) -> Tool:

View File

@@ -80,7 +80,7 @@ class FuncMetadata(BaseModel):
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
"""
new_data = data.copy() # Shallow copy
for field_name, field_info in self.arg_model.model_fields.items():
for field_name, _field_info in self.arg_model.model_fields.items():
if field_name not in data.keys():
continue
if isinstance(data[field_name], str):
@@ -177,7 +177,9 @@ def func_metadata(
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
def try_eval_type(value, globalns, localns):
def try_eval_type(
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
) -> tuple[Any, bool]:
try:
return eval_type_backport(value, globalns, localns), True
except NameError:
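
Editor's note: strict mode also wants nested helpers fully annotated, which is why `try_eval_type` gains explicit parameter and return types. A generic sketch of the same idea with an invented helper (not the SDK's `eval_type_backport` wrapper):

from typing import Any


def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
    # Nested helpers need explicit annotations too, or pyright --strict
    # reports their parameters and return value as unknown.
    def try_eval(
        value: Any, globalns: dict[str, Any], localns: dict[str, Any]
    ) -> tuple[Any, bool]:
        try:
            if isinstance(value, str):
                return eval(value, globalns, localns), True
            return value, True
        except NameError:
            return value, False

    resolved, _ok = try_eval(annotation, globalns, {})
    return resolved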

View File

@@ -24,7 +24,7 @@ def configure_logging(
Args:
level: the log level to use
"""
handlers = []
handlers: list[logging.Handler] = []
try:
from rich.console import Console
from rich.logging import RichHandler
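
Editor's note: as with `arguments: list[PromptArgument]` earlier, an empty list literal gives the checker nothing to infer from, so `handlers` is annotated up front. A minimal sketch:

import logging

# handlers = [] would be inferred with an unknown element type under
# pyright --strict, so the element type is pinned explicitly.
handlers: list[logging.Handler] = []
handlers.append(logging.StreamHandler())

logging.basicConfig(level=logging.INFO, handlers=handlers)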

View File

@@ -69,9 +69,9 @@ from __future__ import annotations as _annotations
import contextvars
import logging
import warnings
from collections.abc import Awaitable, Callable, Iterable
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Generic, TypeVar
from typing import Any, Generic, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -155,9 +155,7 @@ class Server(Generic[LifespanResultT]):
try:
from importlib.metadata import version
v = version(package)
if v is not None:
return v
return version(package)
except Exception:
pass
@@ -320,7 +318,6 @@ class Server(Generic[LifespanResultT]):
contents_list = [
create_content(content_item.content, content_item.mime_type)
for content_item in contents
if isinstance(content_item, ReadResourceContents)
]
return types.ServerResult(
types.ReadResourceResult(
@@ -511,7 +508,8 @@ class Server(Generic[LifespanResultT]):
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
match message:
# TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive]
case (
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
@@ -527,7 +525,7 @@ class Server(Generic[LifespanResultT]):
async def _handle_request(
self,
message: RequestResponder,
message: RequestResponder[types.ClientRequest, types.ServerResult],
req: Any,
session: ServerSession,
lifespan_context: LifespanResultT,

View File

@@ -2,9 +2,10 @@
In-memory transports
"""
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator
from typing import Any
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -52,7 +53,7 @@ async def create_client_server_memory_streams() -> (
@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server,
server: Server[Any],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,

View File

@@ -1,10 +1,19 @@
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Generic
from pydantic import BaseModel
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession
from mcp.shared.session import (
BaseSession,
ReceiveNotificationT,
ReceiveRequestT,
SendNotificationT,
SendRequestT,
SendResultT,
)
from mcp.types import ProgressToken
@@ -14,8 +23,22 @@ class Progress(BaseModel):
@dataclass
class ProgressContext:
session: BaseSession
class ProgressContext(
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
]
):
session: BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
]
progress_token: ProgressToken
total: float | None
current: float = field(default=0.0, init=False)
@@ -29,7 +52,27 @@ class ProgressContext:
@contextmanager
def progress(ctx: RequestContext, total: float | None = None):
def progress(
ctx: RequestContext[
BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
]
],
total: float | None = None,
) -> Generator[
ProgressContext[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
None,
]:
if ctx.meta is None or ctx.meta.progressToken is None:
raise ValueError("No progress token provided")

View File

@@ -1,7 +1,9 @@
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Any, Callable, Generic, TypeVar
from types import TracebackType
from typing import Any, Generic, TypeVar
import anyio
import anyio.lowlevel
@@ -86,7 +88,12 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._cancel_scope.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
@@ -112,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
if not self.cancelled:
self._completed = True
await self._session._send_response(
await self._session._send_response( # type: ignore[reportPrivateUsage]
request_id=self.request_id, response=response
)
@@ -126,7 +133,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
await self._session._send_response(
await self._session._send_response( # type: ignore[reportPrivateUsage]
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@@ -137,7 +144,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@property
def cancelled(self) -> bool:
return self._cancel_scope is not None and self._cancel_scope.cancel_called
return self._cancel_scope.cancel_called
class BaseSession(
@@ -202,7 +209,12 @@ class BaseSession(
self._task_group.start_soon(self._receive_loop)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
@@ -324,7 +336,7 @@ class BaseSession(
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._completed:
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification):
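
Editor's note: the `__exit__`/`__aexit__` dunders are filled in with the standard typed signature using `types.TracebackType`. A self-contained sketch of a synchronous context manager typed the same way:

from types import TracebackType


class Managed:
    """Illustrative context manager with fully typed dunder methods."""

    def __enter__(self) -> "Managed":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Returning None (rather than True) lets exceptions propagate.
        pass


with Managed():
    print("inside the block")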

View File

@@ -1,7 +1,7 @@
from collections.abc import Callable
from typing import (
Annotated,
Any,
Callable,
Generic,
Literal,
TypeAlias,
@@ -89,6 +89,7 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
"""Base class for JSON-RPC notifications."""
method: MethodT
params: NotificationParamsT
model_config = ConfigDict(extra="allow")
@@ -1010,7 +1011,9 @@ class CancelledNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")
class CancelledNotification(Notification):
class CancelledNotification(
Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]
):
"""
This notification can be sent by either side to indicate that it is cancelling a
previously-issued request.
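
Editor's note: `CancelledNotification` now binds `Notification`'s type parameters, so `method` and `params` resolve to concrete types instead of unsolved type variables. A reduced sketch of the pattern with a hypothetical params model (the real base in types.py has the same shape):

from typing import Generic, Literal, TypeVar

from pydantic import BaseModel, ConfigDict

ParamsT = TypeVar("ParamsT", bound=BaseModel)
MethodT = TypeVar("MethodT", bound=str)


class Notification(BaseModel, Generic[ParamsT, MethodT]):
    method: MethodT
    params: ParamsT
    model_config = ConfigDict(extra="allow")


class CancelledParams(BaseModel):
    requestId: int


# Binding the parameters in the subclass gives strict mode concrete types
# for `method` and `params`.
class CancelledNotification(
    Notification[CancelledParams, Literal["notifications/cancelled"]]
):
    pass


note = CancelledNotification(
    method="notifications/cancelled", params=CancelledParams(requestId=1)
)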