mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Add strict mode to pyright (#315)
* Add strict mode to pyright * Apply UP rule * fix readme * More correct * Leave wrong Context for now * Add strict mode to pyright * Apply UP rule * fix readme * fix * ignore
This commit is contained in:
committed by
GitHub
parent
5a54d82459
commit
ae77772ea8
@@ -2,8 +2,8 @@
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Awaitable, Literal, Sequence
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
import pydantic_core
|
||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||
@@ -19,7 +19,7 @@ class Message(BaseModel):
|
||||
role: Literal["user", "assistant"]
|
||||
content: CONTENT_TYPES
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
if isinstance(content, str):
|
||||
content = TextContent(type="text", text=content)
|
||||
super().__init__(content=content, **kwargs)
|
||||
@@ -30,7 +30,7 @@ class UserMessage(Message):
|
||||
|
||||
role: Literal["user", "assistant"] = "user"
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
|
||||
|
||||
role: Literal["user", "assistant"] = "assistant"
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
message_validator = TypeAdapter(UserMessage | AssistantMessage)
|
||||
message_validator = TypeAdapter[UserMessage | AssistantMessage](
|
||||
UserMessage | AssistantMessage
|
||||
)
|
||||
|
||||
SyncPromptResult = (
|
||||
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
||||
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
|
||||
arguments: list[PromptArgument] | None = Field(
|
||||
None, description="Arguments that can be passed to the prompt"
|
||||
)
|
||||
fn: Callable = Field(exclude=True)
|
||||
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable[..., PromptResult],
|
||||
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> "Prompt":
|
||||
@@ -99,7 +101,7 @@ class Prompt(BaseModel):
|
||||
parameters = TypeAdapter(fn).json_schema()
|
||||
|
||||
# Convert parameters to PromptArguments
|
||||
arguments = []
|
||||
arguments: list[PromptArgument] = []
|
||||
if "properties" in parameters:
|
||||
for param_name, param in parameters["properties"].items():
|
||||
required = param_name in parameters.get("required", [])
|
||||
@@ -138,25 +140,23 @@ class Prompt(BaseModel):
|
||||
result = await result
|
||||
|
||||
# Validate messages
|
||||
if not isinstance(result, (list, tuple)):
|
||||
if not isinstance(result, list | tuple):
|
||||
result = [result]
|
||||
|
||||
# Convert result to messages
|
||||
messages = []
|
||||
for msg in result:
|
||||
messages: list[Message] = []
|
||||
for msg in result: # type: ignore[reportUnknownVariableType]
|
||||
try:
|
||||
if isinstance(msg, Message):
|
||||
messages.append(msg)
|
||||
elif isinstance(msg, dict):
|
||||
msg = message_validator.validate_python(msg)
|
||||
messages.append(msg)
|
||||
messages.append(message_validator.validate_python(msg))
|
||||
elif isinstance(msg, str):
|
||||
messages.append(
|
||||
UserMessage(content=TextContent(type="text", text=msg))
|
||||
)
|
||||
content = TextContent(type="text", text=msg)
|
||||
messages.append(UserMessage(content=content))
|
||||
else:
|
||||
msg = json.dumps(pydantic_core.to_jsonable_python(msg))
|
||||
messages.append(Message(role="user", content=msg))
|
||||
content = json.dumps(pydantic_core.to_jsonable_python(msg))
|
||||
messages.append(Message(role="user", content=content))
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Could not convert prompt result to message: {msg}"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Resource manager functionality."""
|
||||
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AnyUrl
|
||||
|
||||
@@ -47,7 +48,7 @@ class ResourceManager:
|
||||
|
||||
def add_template(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable[..., Any],
|
||||
uri_template: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""Resource template functionality."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||
|
||||
@@ -20,18 +23,20 @@ class ResourceTemplate(BaseModel):
|
||||
mime_type: str = Field(
|
||||
default="text/plain", description="MIME type of the resource content"
|
||||
)
|
||||
fn: Callable = Field(exclude=True)
|
||||
parameters: dict = Field(description="JSON schema for function parameters")
|
||||
fn: Callable[..., Any] = Field(exclude=True)
|
||||
parameters: dict[str, Any] = Field(
|
||||
description="JSON schema for function parameters"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable,
|
||||
fn: Callable[..., Any],
|
||||
uri_template: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> "ResourceTemplate":
|
||||
) -> ResourceTemplate:
|
||||
"""Create a template from a function."""
|
||||
func_name = name or fn.__name__
|
||||
if func_name == "<lambda>":
|
||||
|
||||
@@ -5,13 +5,13 @@ from __future__ import annotations as _annotations
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
|
||||
from contextlib import (
|
||||
AbstractAsyncContextManager,
|
||||
asynccontextmanager,
|
||||
)
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Generic, Literal, Sequence
|
||||
from typing import Any, Generic, Literal
|
||||
|
||||
import anyio
|
||||
import pydantic_core
|
||||
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyUrl
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from mcp.server.fastmcp.exceptions import ResourceError
|
||||
@@ -88,13 +89,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
)
|
||||
|
||||
lifespan: (
|
||||
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||
) = Field(None, description="Lifespan context manager")
|
||||
|
||||
|
||||
def lifespan_wrapper(
|
||||
app: FastMCP,
|
||||
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
|
||||
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
|
||||
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
|
||||
@asynccontextmanager
|
||||
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
|
||||
@@ -179,7 +180,7 @@ class FastMCP:
|
||||
for info in tools
|
||||
]
|
||||
|
||||
def get_context(self) -> "Context[ServerSession, object]":
|
||||
def get_context(self) -> Context[ServerSession, object]:
|
||||
"""
|
||||
Returns a Context object. Note that the context will only be valid
|
||||
during a request; outside a request, most methods will error.
|
||||
@@ -478,9 +479,11 @@ class FastMCP:
|
||||
"""Return an instance of the SSE server app."""
|
||||
sse = SseServerTransport("/messages/")
|
||||
|
||||
async def handle_sse(request):
|
||||
async def handle_sse(request: Request) -> None:
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
request.scope,
|
||||
request.receive,
|
||||
request._send, # type: ignore[reportPrivateUsage]
|
||||
) as streams:
|
||||
await self._mcp_server.run(
|
||||
streams[0],
|
||||
@@ -535,14 +538,14 @@ def _convert_to_content(
|
||||
if result is None:
|
||||
return []
|
||||
|
||||
if isinstance(result, (TextContent, ImageContent, EmbeddedResource)):
|
||||
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
|
||||
return [result]
|
||||
|
||||
if isinstance(result, Image):
|
||||
return [result.to_image_content()]
|
||||
|
||||
if isinstance(result, (list, tuple)):
|
||||
return list(chain.from_iterable(_convert_to_content(item) for item in result))
|
||||
if isinstance(result, list | tuple):
|
||||
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
|
||||
|
||||
if not isinstance(result, str):
|
||||
try:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import mcp.server.fastmcp
|
||||
from mcp.server.fastmcp.exceptions import ToolError
|
||||
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
|
||||
|
||||
@@ -38,8 +38,10 @@ class Tool(BaseModel):
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
context_kwarg: str | None = None,
|
||||
) -> "Tool":
|
||||
) -> Tool:
|
||||
"""Create a Tool from a function."""
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
func_name = name or fn.__name__
|
||||
|
||||
if func_name == "<lambda>":
|
||||
@@ -48,11 +50,10 @@ class Tool(BaseModel):
|
||||
func_doc = description or fn.__doc__ or ""
|
||||
is_async = inspect.iscoroutinefunction(fn)
|
||||
|
||||
# Find context parameter if it exists
|
||||
if context_kwarg is None:
|
||||
sig = inspect.signature(fn)
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.annotation is mcp.server.fastmcp.Context:
|
||||
if param.annotation is Context:
|
||||
context_kwarg = param_name
|
||||
break
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class ToolManager:
|
||||
|
||||
def add_tool(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable[..., Any],
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> Tool:
|
||||
|
||||
@@ -80,7 +80,7 @@ class FuncMetadata(BaseModel):
|
||||
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
|
||||
"""
|
||||
new_data = data.copy() # Shallow copy
|
||||
for field_name, field_info in self.arg_model.model_fields.items():
|
||||
for field_name, _field_info in self.arg_model.model_fields.items():
|
||||
if field_name not in data.keys():
|
||||
continue
|
||||
if isinstance(data[field_name], str):
|
||||
@@ -177,7 +177,9 @@ def func_metadata(
|
||||
|
||||
|
||||
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
||||
def try_eval_type(value, globalns, localns):
|
||||
def try_eval_type(
|
||||
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
|
||||
) -> tuple[Any, bool]:
|
||||
try:
|
||||
return eval_type_backport(value, globalns, localns), True
|
||||
except NameError:
|
||||
|
||||
@@ -24,7 +24,7 @@ def configure_logging(
|
||||
Args:
|
||||
level: the log level to use
|
||||
"""
|
||||
handlers = []
|
||||
handlers: list[logging.Handler] = []
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
@@ -69,9 +69,9 @@ from __future__ import annotations as _annotations
|
||||
import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
@@ -155,9 +155,7 @@ class Server(Generic[LifespanResultT]):
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
v = version(package)
|
||||
if v is not None:
|
||||
return v
|
||||
return version(package)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -320,7 +318,6 @@ class Server(Generic[LifespanResultT]):
|
||||
contents_list = [
|
||||
create_content(content_item.content, content_item.mime_type)
|
||||
for content_item in contents
|
||||
if isinstance(content_item, ReadResourceContents)
|
||||
]
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
@@ -511,7 +508,8 @@ class Server(Generic[LifespanResultT]):
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
match message:
|
||||
# TODO(Marcelo): We should be checking if message is Exception here.
|
||||
match message: # type: ignore[reportMatchNotExhaustive]
|
||||
case (
|
||||
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
||||
):
|
||||
@@ -527,7 +525,7 @@ class Server(Generic[LifespanResultT]):
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
message: RequestResponder,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult],
|
||||
req: Any,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
|
||||
Reference in New Issue
Block a user