Add strict mode to pyright (#315)

* Add strict mode to pyright

* Apply UP rule

* fix readme

* More correct

* Leave wrong Context for now

* Add strict mode to pyright

* Apply UP rule

* fix readme

* fix

* ignore
This commit is contained in:
Marcelo Trylesinski
2025-03-20 09:13:08 +00:00
committed by GitHub
parent 5a54d82459
commit ae77772ea8
27 changed files with 194 additions and 133 deletions

View File

@@ -2,8 +2,8 @@
import inspect
import json
from collections.abc import Callable
from typing import Any, Awaitable, Literal, Sequence
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -19,7 +19,7 @@ class Message(BaseModel):
role: Literal["user", "assistant"]
content: CONTENT_TYPES
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
if isinstance(content, str):
content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs)
@@ -30,7 +30,7 @@ class UserMessage(Message):
role: Literal["user", "assistant"] = "user"
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
role: Literal["user", "assistant"] = "assistant"
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
message_validator = TypeAdapter(UserMessage | AssistantMessage)
message_validator = TypeAdapter[UserMessage | AssistantMessage](
UserMessage | AssistantMessage
)
SyncPromptResult = (
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
arguments: list[PromptArgument] | None = Field(
None, description="Arguments that can be passed to the prompt"
)
fn: Callable = Field(exclude=True)
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
@classmethod
def from_function(
cls,
fn: Callable[..., PromptResult],
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
name: str | None = None,
description: str | None = None,
) -> "Prompt":
@@ -99,7 +101,7 @@ class Prompt(BaseModel):
parameters = TypeAdapter(fn).json_schema()
# Convert parameters to PromptArguments
arguments = []
arguments: list[PromptArgument] = []
if "properties" in parameters:
for param_name, param in parameters["properties"].items():
required = param_name in parameters.get("required", [])
@@ -138,25 +140,23 @@ class Prompt(BaseModel):
result = await result
# Validate messages
if not isinstance(result, (list, tuple)):
if not isinstance(result, list | tuple):
result = [result]
# Convert result to messages
messages = []
for msg in result:
messages: list[Message] = []
for msg in result: # type: ignore[reportUnknownVariableType]
try:
if isinstance(msg, Message):
messages.append(msg)
elif isinstance(msg, dict):
msg = message_validator.validate_python(msg)
messages.append(msg)
messages.append(message_validator.validate_python(msg))
elif isinstance(msg, str):
messages.append(
UserMessage(content=TextContent(type="text", text=msg))
)
content = TextContent(type="text", text=msg)
messages.append(UserMessage(content=content))
else:
msg = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=msg))
content = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=content))
except Exception:
raise ValueError(
f"Could not convert prompt result to message: {msg}"

View File

@@ -1,6 +1,7 @@
"""Resource manager functionality."""
from typing import Callable
from collections.abc import Callable
from typing import Any
from pydantic import AnyUrl
@@ -47,7 +48,7 @@ class ResourceManager:
def add_template(
self,
fn: Callable,
fn: Callable[..., Any],
uri_template: str,
name: str | None = None,
description: str | None = None,

View File

@@ -1,8 +1,11 @@
"""Resource template functionality."""
from __future__ import annotations
import inspect
import re
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -20,18 +23,20 @@ class ResourceTemplate(BaseModel):
mime_type: str = Field(
default="text/plain", description="MIME type of the resource content"
)
fn: Callable = Field(exclude=True)
parameters: dict = Field(description="JSON schema for function parameters")
fn: Callable[..., Any] = Field(exclude=True)
parameters: dict[str, Any] = Field(
description="JSON schema for function parameters"
)
@classmethod
def from_function(
cls,
fn: Callable,
fn: Callable[..., Any],
uri_template: str,
name: str | None = None,
description: str | None = None,
mime_type: str | None = None,
) -> "ResourceTemplate":
) -> ResourceTemplate:
"""Create a template from a function."""
func_name = name or fn.__name__
if func_name == "<lambda>":

View File

@@ -5,13 +5,13 @@ from __future__ import annotations as _annotations
import inspect
import json
import re
from collections.abc import AsyncIterator, Iterable
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from itertools import chain
from typing import Any, Callable, Generic, Literal, Sequence
from typing import Any, Generic, Literal
import anyio
import pydantic_core
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route
from mcp.server.fastmcp.exceptions import ResourceError
@@ -88,13 +89,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
)
lifespan: (
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager")
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
@@ -179,7 +180,7 @@ class FastMCP:
for info in tools
]
def get_context(self) -> "Context[ServerSession, object]":
def get_context(self) -> Context[ServerSession, object]:
"""
Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error.
@@ -478,9 +479,11 @@ class FastMCP:
"""Return an instance of the SSE server app."""
sse = SseServerTransport("/messages/")
async def handle_sse(request):
async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
request.scope, request.receive, request._send
request.scope,
request.receive,
request._send, # type: ignore[reportPrivateUsage]
) as streams:
await self._mcp_server.run(
streams[0],
@@ -535,14 +538,14 @@ def _convert_to_content(
if result is None:
return []
if isinstance(result, (TextContent, ImageContent, EmbeddedResource)):
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
return [result]
if isinstance(result, Image):
return [result.to_image_content()]
if isinstance(result, (list, tuple)):
return list(chain.from_iterable(_convert_to_content(item) for item in result))
if isinstance(result, list | tuple):
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
if not isinstance(result, str):
try:

View File

@@ -1,11 +1,11 @@
from __future__ import annotations as _annotations
import inspect
from typing import TYPE_CHECKING, Any, Callable
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
import mcp.server.fastmcp
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
@@ -38,8 +38,10 @@ class Tool(BaseModel):
name: str | None = None,
description: str | None = None,
context_kwarg: str | None = None,
) -> "Tool":
) -> Tool:
"""Create a Tool from a function."""
from mcp.server.fastmcp import Context
func_name = name or fn.__name__
if func_name == "<lambda>":
@@ -48,11 +50,10 @@ class Tool(BaseModel):
func_doc = description or fn.__doc__ or ""
is_async = inspect.iscoroutinefunction(fn)
# Find context parameter if it exists
if context_kwarg is None:
sig = inspect.signature(fn)
for param_name, param in sig.parameters.items():
if param.annotation is mcp.server.fastmcp.Context:
if param.annotation is Context:
context_kwarg = param_name
break

View File

@@ -32,7 +32,7 @@ class ToolManager:
def add_tool(
self,
fn: Callable,
fn: Callable[..., Any],
name: str | None = None,
description: str | None = None,
) -> Tool:

View File

@@ -80,7 +80,7 @@ class FuncMetadata(BaseModel):
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
"""
new_data = data.copy() # Shallow copy
for field_name, field_info in self.arg_model.model_fields.items():
for field_name, _field_info in self.arg_model.model_fields.items():
if field_name not in data.keys():
continue
if isinstance(data[field_name], str):
@@ -177,7 +177,9 @@ def func_metadata(
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
def try_eval_type(value, globalns, localns):
def try_eval_type(
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
) -> tuple[Any, bool]:
try:
return eval_type_backport(value, globalns, localns), True
except NameError:

View File

@@ -24,7 +24,7 @@ def configure_logging(
Args:
level: the log level to use
"""
handlers = []
handlers: list[logging.Handler] = []
try:
from rich.console import Console
from rich.logging import RichHandler