Add strict mode to pyright (#315)

* Add strict mode to pyright

* Apply UP rule

* fix readme

* More correct

* Leave wrong Context for now

* Add strict mode to pyright

* Apply UP rule

* fix readme

* fix

* ignore
This commit is contained in:
Marcelo Trylesinski
2025-03-20 09:13:08 +00:00
committed by GitHub
parent 5a54d82459
commit ae77772ea8
27 changed files with 194 additions and 133 deletions

View File

@@ -2,8 +2,8 @@
import inspect
import json
from collections.abc import Callable
from typing import Any, Awaitable, Literal, Sequence
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -19,7 +19,7 @@ class Message(BaseModel):
role: Literal["user", "assistant"]
content: CONTENT_TYPES
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
if isinstance(content, str):
content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs)
@@ -30,7 +30,7 @@ class UserMessage(Message):
role: Literal["user", "assistant"] = "user"
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
role: Literal["user", "assistant"] = "assistant"
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
message_validator = TypeAdapter(UserMessage | AssistantMessage)
message_validator = TypeAdapter[UserMessage | AssistantMessage](
UserMessage | AssistantMessage
)
SyncPromptResult = (
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
arguments: list[PromptArgument] | None = Field(
None, description="Arguments that can be passed to the prompt"
)
fn: Callable = Field(exclude=True)
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
@classmethod
def from_function(
cls,
fn: Callable[..., PromptResult],
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
name: str | None = None,
description: str | None = None,
) -> "Prompt":
@@ -99,7 +101,7 @@ class Prompt(BaseModel):
parameters = TypeAdapter(fn).json_schema()
# Convert parameters to PromptArguments
arguments = []
arguments: list[PromptArgument] = []
if "properties" in parameters:
for param_name, param in parameters["properties"].items():
required = param_name in parameters.get("required", [])
@@ -138,25 +140,23 @@ class Prompt(BaseModel):
result = await result
# Validate messages
if not isinstance(result, (list, tuple)):
if not isinstance(result, list | tuple):
result = [result]
# Convert result to messages
messages = []
for msg in result:
messages: list[Message] = []
for msg in result: # type: ignore[reportUnknownVariableType]
try:
if isinstance(msg, Message):
messages.append(msg)
elif isinstance(msg, dict):
msg = message_validator.validate_python(msg)
messages.append(msg)
messages.append(message_validator.validate_python(msg))
elif isinstance(msg, str):
messages.append(
UserMessage(content=TextContent(type="text", text=msg))
)
content = TextContent(type="text", text=msg)
messages.append(UserMessage(content=content))
else:
msg = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=msg))
content = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=content))
except Exception:
raise ValueError(
f"Could not convert prompt result to message: {msg}"

View File

@@ -1,6 +1,7 @@
"""Resource manager functionality."""
from typing import Callable
from collections.abc import Callable
from typing import Any
from pydantic import AnyUrl
@@ -47,7 +48,7 @@ class ResourceManager:
def add_template(
self,
fn: Callable,
fn: Callable[..., Any],
uri_template: str,
name: str | None = None,
description: str | None = None,

View File

@@ -1,8 +1,11 @@
"""Resource template functionality."""
from __future__ import annotations
import inspect
import re
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -20,18 +23,20 @@ class ResourceTemplate(BaseModel):
mime_type: str = Field(
default="text/plain", description="MIME type of the resource content"
)
fn: Callable = Field(exclude=True)
parameters: dict = Field(description="JSON schema for function parameters")
fn: Callable[..., Any] = Field(exclude=True)
parameters: dict[str, Any] = Field(
description="JSON schema for function parameters"
)
@classmethod
def from_function(
cls,
fn: Callable,
fn: Callable[..., Any],
uri_template: str,
name: str | None = None,
description: str | None = None,
mime_type: str | None = None,
) -> "ResourceTemplate":
) -> ResourceTemplate:
"""Create a template from a function."""
func_name = name or fn.__name__
if func_name == "<lambda>":

View File

@@ -5,13 +5,13 @@ from __future__ import annotations as _annotations
import inspect
import json
import re
from collections.abc import AsyncIterator, Iterable
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from itertools import chain
from typing import Any, Callable, Generic, Literal, Sequence
from typing import Any, Generic, Literal
import anyio
import pydantic_core
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route
from mcp.server.fastmcp.exceptions import ResourceError
@@ -88,13 +89,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
)
lifespan: (
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager")
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
@@ -179,7 +180,7 @@ class FastMCP:
for info in tools
]
def get_context(self) -> "Context[ServerSession, object]":
def get_context(self) -> Context[ServerSession, object]:
"""
Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error.
@@ -478,9 +479,11 @@ class FastMCP:
"""Return an instance of the SSE server app."""
sse = SseServerTransport("/messages/")
async def handle_sse(request):
async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
request.scope, request.receive, request._send
request.scope,
request.receive,
request._send, # type: ignore[reportPrivateUsage]
) as streams:
await self._mcp_server.run(
streams[0],
@@ -535,14 +538,14 @@ def _convert_to_content(
if result is None:
return []
if isinstance(result, (TextContent, ImageContent, EmbeddedResource)):
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
return [result]
if isinstance(result, Image):
return [result.to_image_content()]
if isinstance(result, (list, tuple)):
return list(chain.from_iterable(_convert_to_content(item) for item in result))
if isinstance(result, list | tuple):
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
if not isinstance(result, str):
try:

View File

@@ -1,11 +1,11 @@
from __future__ import annotations as _annotations
import inspect
from typing import TYPE_CHECKING, Any, Callable
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
import mcp.server.fastmcp
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
@@ -38,8 +38,10 @@ class Tool(BaseModel):
name: str | None = None,
description: str | None = None,
context_kwarg: str | None = None,
) -> "Tool":
) -> Tool:
"""Create a Tool from a function."""
from mcp.server.fastmcp import Context
func_name = name or fn.__name__
if func_name == "<lambda>":
@@ -48,11 +50,10 @@ class Tool(BaseModel):
func_doc = description or fn.__doc__ or ""
is_async = inspect.iscoroutinefunction(fn)
# Find context parameter if it exists
if context_kwarg is None:
sig = inspect.signature(fn)
for param_name, param in sig.parameters.items():
if param.annotation is mcp.server.fastmcp.Context:
if param.annotation is Context:
context_kwarg = param_name
break

View File

@@ -32,7 +32,7 @@ class ToolManager:
def add_tool(
self,
fn: Callable,
fn: Callable[..., Any],
name: str | None = None,
description: str | None = None,
) -> Tool:

View File

@@ -80,7 +80,7 @@ class FuncMetadata(BaseModel):
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
"""
new_data = data.copy() # Shallow copy
for field_name, field_info in self.arg_model.model_fields.items():
for field_name, _field_info in self.arg_model.model_fields.items():
if field_name not in data.keys():
continue
if isinstance(data[field_name], str):
@@ -177,7 +177,9 @@ def func_metadata(
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
def try_eval_type(value, globalns, localns):
def try_eval_type(
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
) -> tuple[Any, bool]:
try:
return eval_type_backport(value, globalns, localns), True
except NameError:

View File

@@ -24,7 +24,7 @@ def configure_logging(
Args:
level: the log level to use
"""
handlers = []
handlers: list[logging.Handler] = []
try:
from rich.console import Console
from rich.logging import RichHandler

View File

@@ -69,9 +69,9 @@ from __future__ import annotations as _annotations
import contextvars
import logging
import warnings
from collections.abc import Awaitable, Callable, Iterable
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Generic, TypeVar
from typing import Any, Generic, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -155,9 +155,7 @@ class Server(Generic[LifespanResultT]):
try:
from importlib.metadata import version
v = version(package)
if v is not None:
return v
return version(package)
except Exception:
pass
@@ -320,7 +318,6 @@ class Server(Generic[LifespanResultT]):
contents_list = [
create_content(content_item.content, content_item.mime_type)
for content_item in contents
if isinstance(content_item, ReadResourceContents)
]
return types.ServerResult(
types.ReadResourceResult(
@@ -511,7 +508,8 @@ class Server(Generic[LifespanResultT]):
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
match message:
# TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive]
case (
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
@@ -527,7 +525,7 @@ class Server(Generic[LifespanResultT]):
async def _handle_request(
self,
message: RequestResponder,
message: RequestResponder[types.ClientRequest, types.ServerResult],
req: Any,
session: ServerSession,
lifespan_context: LifespanResultT,