mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Add strict mode to pyright (#315)
* Add strict mode to pyright * Apply UP rule * fix readme * More correct * Leave wrong Context for now * Add strict mode to pyright * Apply UP rule * fix readme * fix * ignore
This commit is contained in:
committed by
GitHub
parent
5a54d82459
commit
ae77772ea8
@@ -16,8 +16,11 @@
|
|||||||
<!-- omit in toc -->
|
<!-- omit in toc -->
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
|
- [MCP Python SDK](#mcp-python-sdk)
|
||||||
- [Overview](#overview)
|
- [Overview](#overview)
|
||||||
- [Installation](#installation)
|
- [Installation](#installation)
|
||||||
|
- [Adding MCP to your python project](#adding-mcp-to-your-python-project)
|
||||||
|
- [Running the standalone MCP development tools](#running-the-standalone-mcp-development-tools)
|
||||||
- [Quickstart](#quickstart)
|
- [Quickstart](#quickstart)
|
||||||
- [What is MCP?](#what-is-mcp)
|
- [What is MCP?](#what-is-mcp)
|
||||||
- [Core Concepts](#core-concepts)
|
- [Core Concepts](#core-concepts)
|
||||||
@@ -143,8 +146,8 @@ The FastMCP server is your core interface to the MCP protocol. It handles connec
|
|||||||
```python
|
```python
|
||||||
# Add lifespan support for startup/shutdown with strong typing
|
# Add lifespan support for startup/shutdown with strong typing
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import AsyncIterator
|
|
||||||
|
|
||||||
from fake_database import Database # Replace with your actual DB type
|
from fake_database import Database # Replace with your actual DB type
|
||||||
|
|
||||||
@@ -442,7 +445,7 @@ For more control, you can use the low-level server implementation directly. This
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
from fake_database import Database # Replace with your actual DB type
|
from fake_database import Database # Replace with your actual DB type
|
||||||
|
|
||||||
|
|||||||
@@ -76,13 +76,11 @@ packages = ["src/mcp"]
|
|||||||
include = ["src/mcp", "tests"]
|
include = ["src/mcp", "tests"]
|
||||||
venvPath = "."
|
venvPath = "."
|
||||||
venv = ".venv"
|
venv = ".venv"
|
||||||
strict = [
|
strict = ["src/mcp/**/*.py"]
|
||||||
"src/mcp/server/fastmcp/tools/base.py",
|
exclude = ["src/mcp/types.py"]
|
||||||
"src/mcp/client/*.py"
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["E", "F", "I"]
|
select = ["E", "F", "I", "UP"]
|
||||||
ignore = []
|
ignore = []
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||||
|
|
||||||
@@ -116,10 +117,7 @@ def update_claude_config(
|
|||||||
# Add fastmcp run command
|
# Add fastmcp run command
|
||||||
args.extend(["mcp", "run", file_spec])
|
args.extend(["mcp", "run", file_spec])
|
||||||
|
|
||||||
server_config = {
|
server_config: dict[str, Any] = {"command": "uv", "args": args}
|
||||||
"command": "uv",
|
|
||||||
"args": args,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add environment variables if specified
|
# Add environment variables if specified
|
||||||
if env_vars:
|
if env_vars:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
from collections.abc import Callable
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
from typing import Any, Awaitable, Literal, Sequence
|
from typing import Any, Literal
|
||||||
|
|
||||||
import pydantic_core
|
import pydantic_core
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||||
@@ -19,7 +19,7 @@ class Message(BaseModel):
|
|||||||
role: Literal["user", "assistant"]
|
role: Literal["user", "assistant"]
|
||||||
content: CONTENT_TYPES
|
content: CONTENT_TYPES
|
||||||
|
|
||||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
content = TextContent(type="text", text=content)
|
content = TextContent(type="text", text=content)
|
||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
@@ -30,7 +30,7 @@ class UserMessage(Message):
|
|||||||
|
|
||||||
role: Literal["user", "assistant"] = "user"
|
role: Literal["user", "assistant"] = "user"
|
||||||
|
|
||||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
|
|||||||
|
|
||||||
role: Literal["user", "assistant"] = "assistant"
|
role: Literal["user", "assistant"] = "assistant"
|
||||||
|
|
||||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
message_validator = TypeAdapter(UserMessage | AssistantMessage)
|
message_validator = TypeAdapter[UserMessage | AssistantMessage](
|
||||||
|
UserMessage | AssistantMessage
|
||||||
|
)
|
||||||
|
|
||||||
SyncPromptResult = (
|
SyncPromptResult = (
|
||||||
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
||||||
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
|
|||||||
arguments: list[PromptArgument] | None = Field(
|
arguments: list[PromptArgument] | None = Field(
|
||||||
None, description="Arguments that can be passed to the prompt"
|
None, description="Arguments that can be passed to the prompt"
|
||||||
)
|
)
|
||||||
fn: Callable = Field(exclude=True)
|
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_function(
|
def from_function(
|
||||||
cls,
|
cls,
|
||||||
fn: Callable[..., PromptResult],
|
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
description: str | None = None,
|
description: str | None = None,
|
||||||
) -> "Prompt":
|
) -> "Prompt":
|
||||||
@@ -99,7 +101,7 @@ class Prompt(BaseModel):
|
|||||||
parameters = TypeAdapter(fn).json_schema()
|
parameters = TypeAdapter(fn).json_schema()
|
||||||
|
|
||||||
# Convert parameters to PromptArguments
|
# Convert parameters to PromptArguments
|
||||||
arguments = []
|
arguments: list[PromptArgument] = []
|
||||||
if "properties" in parameters:
|
if "properties" in parameters:
|
||||||
for param_name, param in parameters["properties"].items():
|
for param_name, param in parameters["properties"].items():
|
||||||
required = param_name in parameters.get("required", [])
|
required = param_name in parameters.get("required", [])
|
||||||
@@ -138,25 +140,23 @@ class Prompt(BaseModel):
|
|||||||
result = await result
|
result = await result
|
||||||
|
|
||||||
# Validate messages
|
# Validate messages
|
||||||
if not isinstance(result, (list, tuple)):
|
if not isinstance(result, list | tuple):
|
||||||
result = [result]
|
result = [result]
|
||||||
|
|
||||||
# Convert result to messages
|
# Convert result to messages
|
||||||
messages = []
|
messages: list[Message] = []
|
||||||
for msg in result:
|
for msg in result: # type: ignore[reportUnknownVariableType]
|
||||||
try:
|
try:
|
||||||
if isinstance(msg, Message):
|
if isinstance(msg, Message):
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
elif isinstance(msg, dict):
|
elif isinstance(msg, dict):
|
||||||
msg = message_validator.validate_python(msg)
|
messages.append(message_validator.validate_python(msg))
|
||||||
messages.append(msg)
|
|
||||||
elif isinstance(msg, str):
|
elif isinstance(msg, str):
|
||||||
messages.append(
|
content = TextContent(type="text", text=msg)
|
||||||
UserMessage(content=TextContent(type="text", text=msg))
|
messages.append(UserMessage(content=content))
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
msg = json.dumps(pydantic_core.to_jsonable_python(msg))
|
content = json.dumps(pydantic_core.to_jsonable_python(msg))
|
||||||
messages.append(Message(role="user", content=msg))
|
messages.append(Message(role="user", content=content))
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Could not convert prompt result to message: {msg}"
|
f"Could not convert prompt result to message: {msg}"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Resource manager functionality."""
|
"""Resource manager functionality."""
|
||||||
|
|
||||||
from typing import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
@@ -47,7 +48,7 @@ class ResourceManager:
|
|||||||
|
|
||||||
def add_template(
|
def add_template(
|
||||||
self,
|
self,
|
||||||
fn: Callable,
|
fn: Callable[..., Any],
|
||||||
uri_template: str,
|
uri_template: str,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
description: str | None = None,
|
description: str | None = None,
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
"""Resource template functionality."""
|
"""Resource template functionality."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
from typing import Any, Callable
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||||
|
|
||||||
@@ -20,18 +23,20 @@ class ResourceTemplate(BaseModel):
|
|||||||
mime_type: str = Field(
|
mime_type: str = Field(
|
||||||
default="text/plain", description="MIME type of the resource content"
|
default="text/plain", description="MIME type of the resource content"
|
||||||
)
|
)
|
||||||
fn: Callable = Field(exclude=True)
|
fn: Callable[..., Any] = Field(exclude=True)
|
||||||
parameters: dict = Field(description="JSON schema for function parameters")
|
parameters: dict[str, Any] = Field(
|
||||||
|
description="JSON schema for function parameters"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_function(
|
def from_function(
|
||||||
cls,
|
cls,
|
||||||
fn: Callable,
|
fn: Callable[..., Any],
|
||||||
uri_template: str,
|
uri_template: str,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
description: str | None = None,
|
description: str | None = None,
|
||||||
mime_type: str | None = None,
|
mime_type: str | None = None,
|
||||||
) -> "ResourceTemplate":
|
) -> ResourceTemplate:
|
||||||
"""Create a template from a function."""
|
"""Create a template from a function."""
|
||||||
func_name = name or fn.__name__
|
func_name = name or fn.__name__
|
||||||
if func_name == "<lambda>":
|
if func_name == "<lambda>":
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ from __future__ import annotations as _annotations
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from collections.abc import AsyncIterator, Iterable
|
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
|
||||||
from contextlib import (
|
from contextlib import (
|
||||||
AbstractAsyncContextManager,
|
AbstractAsyncContextManager,
|
||||||
asynccontextmanager,
|
asynccontextmanager,
|
||||||
)
|
)
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any, Callable, Generic, Literal, Sequence
|
from typing import Any, Generic, Literal
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pydantic_core
|
import pydantic_core
|
||||||
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
|
|||||||
from pydantic.networks import AnyUrl
|
from pydantic.networks import AnyUrl
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
|
from starlette.requests import Request
|
||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
from mcp.server.fastmcp.exceptions import ResourceError
|
from mcp.server.fastmcp.exceptions import ResourceError
|
||||||
@@ -88,13 +89,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
lifespan: (
|
lifespan: (
|
||||||
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
|
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||||
) = Field(None, description="Lifespan context manager")
|
) = Field(None, description="Lifespan context manager")
|
||||||
|
|
||||||
|
|
||||||
def lifespan_wrapper(
|
def lifespan_wrapper(
|
||||||
app: FastMCP,
|
app: FastMCP,
|
||||||
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
|
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
|
||||||
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
|
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
|
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
|
||||||
@@ -179,7 +180,7 @@ class FastMCP:
|
|||||||
for info in tools
|
for info in tools
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_context(self) -> "Context[ServerSession, object]":
|
def get_context(self) -> Context[ServerSession, object]:
|
||||||
"""
|
"""
|
||||||
Returns a Context object. Note that the context will only be valid
|
Returns a Context object. Note that the context will only be valid
|
||||||
during a request; outside a request, most methods will error.
|
during a request; outside a request, most methods will error.
|
||||||
@@ -478,9 +479,11 @@ class FastMCP:
|
|||||||
"""Return an instance of the SSE server app."""
|
"""Return an instance of the SSE server app."""
|
||||||
sse = SseServerTransport("/messages/")
|
sse = SseServerTransport("/messages/")
|
||||||
|
|
||||||
async def handle_sse(request):
|
async def handle_sse(request: Request) -> None:
|
||||||
async with sse.connect_sse(
|
async with sse.connect_sse(
|
||||||
request.scope, request.receive, request._send
|
request.scope,
|
||||||
|
request.receive,
|
||||||
|
request._send, # type: ignore[reportPrivateUsage]
|
||||||
) as streams:
|
) as streams:
|
||||||
await self._mcp_server.run(
|
await self._mcp_server.run(
|
||||||
streams[0],
|
streams[0],
|
||||||
@@ -535,14 +538,14 @@ def _convert_to_content(
|
|||||||
if result is None:
|
if result is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if isinstance(result, (TextContent, ImageContent, EmbeddedResource)):
|
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
|
||||||
return [result]
|
return [result]
|
||||||
|
|
||||||
if isinstance(result, Image):
|
if isinstance(result, Image):
|
||||||
return [result.to_image_content()]
|
return [result.to_image_content()]
|
||||||
|
|
||||||
if isinstance(result, (list, tuple)):
|
if isinstance(result, list | tuple):
|
||||||
return list(chain.from_iterable(_convert_to_content(item) for item in result))
|
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
|
||||||
|
|
||||||
if not isinstance(result, str):
|
if not isinstance(result, str):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from __future__ import annotations as _annotations
|
from __future__ import annotations as _annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import mcp.server.fastmcp
|
|
||||||
from mcp.server.fastmcp.exceptions import ToolError
|
from mcp.server.fastmcp.exceptions import ToolError
|
||||||
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
|
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
|
||||||
|
|
||||||
@@ -38,8 +38,10 @@ class Tool(BaseModel):
|
|||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
description: str | None = None,
|
description: str | None = None,
|
||||||
context_kwarg: str | None = None,
|
context_kwarg: str | None = None,
|
||||||
) -> "Tool":
|
) -> Tool:
|
||||||
"""Create a Tool from a function."""
|
"""Create a Tool from a function."""
|
||||||
|
from mcp.server.fastmcp import Context
|
||||||
|
|
||||||
func_name = name or fn.__name__
|
func_name = name or fn.__name__
|
||||||
|
|
||||||
if func_name == "<lambda>":
|
if func_name == "<lambda>":
|
||||||
@@ -48,11 +50,10 @@ class Tool(BaseModel):
|
|||||||
func_doc = description or fn.__doc__ or ""
|
func_doc = description or fn.__doc__ or ""
|
||||||
is_async = inspect.iscoroutinefunction(fn)
|
is_async = inspect.iscoroutinefunction(fn)
|
||||||
|
|
||||||
# Find context parameter if it exists
|
|
||||||
if context_kwarg is None:
|
if context_kwarg is None:
|
||||||
sig = inspect.signature(fn)
|
sig = inspect.signature(fn)
|
||||||
for param_name, param in sig.parameters.items():
|
for param_name, param in sig.parameters.items():
|
||||||
if param.annotation is mcp.server.fastmcp.Context:
|
if param.annotation is Context:
|
||||||
context_kwarg = param_name
|
context_kwarg = param_name
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class ToolManager:
|
|||||||
|
|
||||||
def add_tool(
|
def add_tool(
|
||||||
self,
|
self,
|
||||||
fn: Callable,
|
fn: Callable[..., Any],
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
description: str | None = None,
|
description: str | None = None,
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class FuncMetadata(BaseModel):
|
|||||||
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
|
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
|
||||||
"""
|
"""
|
||||||
new_data = data.copy() # Shallow copy
|
new_data = data.copy() # Shallow copy
|
||||||
for field_name, field_info in self.arg_model.model_fields.items():
|
for field_name, _field_info in self.arg_model.model_fields.items():
|
||||||
if field_name not in data.keys():
|
if field_name not in data.keys():
|
||||||
continue
|
continue
|
||||||
if isinstance(data[field_name], str):
|
if isinstance(data[field_name], str):
|
||||||
@@ -177,7 +177,9 @@ def func_metadata(
|
|||||||
|
|
||||||
|
|
||||||
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
||||||
def try_eval_type(value, globalns, localns):
|
def try_eval_type(
|
||||||
|
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
|
||||||
|
) -> tuple[Any, bool]:
|
||||||
try:
|
try:
|
||||||
return eval_type_backport(value, globalns, localns), True
|
return eval_type_backport(value, globalns, localns), True
|
||||||
except NameError:
|
except NameError:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def configure_logging(
|
|||||||
Args:
|
Args:
|
||||||
level: the log level to use
|
level: the log level to use
|
||||||
"""
|
"""
|
||||||
handlers = []
|
handlers: list[logging.Handler] = []
|
||||||
try:
|
try:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.logging import RichHandler
|
from rich.logging import RichHandler
|
||||||
|
|||||||
@@ -69,9 +69,9 @@ from __future__ import annotations as _annotations
|
|||||||
import contextvars
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Awaitable, Callable, Iterable
|
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
||||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||||
from typing import Any, AsyncIterator, Generic, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
@@ -155,9 +155,7 @@ class Server(Generic[LifespanResultT]):
|
|||||||
try:
|
try:
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
|
||||||
v = version(package)
|
return version(package)
|
||||||
if v is not None:
|
|
||||||
return v
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -320,7 +318,6 @@ class Server(Generic[LifespanResultT]):
|
|||||||
contents_list = [
|
contents_list = [
|
||||||
create_content(content_item.content, content_item.mime_type)
|
create_content(content_item.content, content_item.mime_type)
|
||||||
for content_item in contents
|
for content_item in contents
|
||||||
if isinstance(content_item, ReadResourceContents)
|
|
||||||
]
|
]
|
||||||
return types.ServerResult(
|
return types.ServerResult(
|
||||||
types.ReadResourceResult(
|
types.ReadResourceResult(
|
||||||
@@ -511,7 +508,8 @@ class Server(Generic[LifespanResultT]):
|
|||||||
raise_exceptions: bool = False,
|
raise_exceptions: bool = False,
|
||||||
):
|
):
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
match message:
|
# TODO(Marcelo): We should be checking if message is Exception here.
|
||||||
|
match message: # type: ignore[reportMatchNotExhaustive]
|
||||||
case (
|
case (
|
||||||
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
||||||
):
|
):
|
||||||
@@ -527,7 +525,7 @@ class Server(Generic[LifespanResultT]):
|
|||||||
|
|
||||||
async def _handle_request(
|
async def _handle_request(
|
||||||
self,
|
self,
|
||||||
message: RequestResponder,
|
message: RequestResponder[types.ClientRequest, types.ServerResult],
|
||||||
req: Any,
|
req: Any,
|
||||||
session: ServerSession,
|
session: ServerSession,
|
||||||
lifespan_context: LifespanResultT,
|
lifespan_context: LifespanResultT,
|
||||||
|
|||||||
@@ -2,9 +2,10 @@
|
|||||||
In-memory transports
|
In-memory transports
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import AsyncGenerator
|
from typing import Any
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
@@ -52,7 +53,7 @@ async def create_client_server_memory_streams() -> (
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def create_connected_server_and_client_session(
|
async def create_connected_server_and_client_session(
|
||||||
server: Server,
|
server: Server[Any],
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
sampling_callback: SamplingFnT | None = None,
|
sampling_callback: SamplingFnT | None = None,
|
||||||
list_roots_callback: ListRootsFnT | None = None,
|
list_roots_callback: ListRootsFnT | None = None,
|
||||||
|
|||||||
@@ -1,10 +1,19 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Generic
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.shared.session import BaseSession
|
from mcp.shared.session import (
|
||||||
|
BaseSession,
|
||||||
|
ReceiveNotificationT,
|
||||||
|
ReceiveRequestT,
|
||||||
|
SendNotificationT,
|
||||||
|
SendRequestT,
|
||||||
|
SendResultT,
|
||||||
|
)
|
||||||
from mcp.types import ProgressToken
|
from mcp.types import ProgressToken
|
||||||
|
|
||||||
|
|
||||||
@@ -14,8 +23,22 @@ class Progress(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProgressContext:
|
class ProgressContext(
|
||||||
session: BaseSession
|
Generic[
|
||||||
|
SendRequestT,
|
||||||
|
SendNotificationT,
|
||||||
|
SendResultT,
|
||||||
|
ReceiveRequestT,
|
||||||
|
ReceiveNotificationT,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
session: BaseSession[
|
||||||
|
SendRequestT,
|
||||||
|
SendNotificationT,
|
||||||
|
SendResultT,
|
||||||
|
ReceiveRequestT,
|
||||||
|
ReceiveNotificationT,
|
||||||
|
]
|
||||||
progress_token: ProgressToken
|
progress_token: ProgressToken
|
||||||
total: float | None
|
total: float | None
|
||||||
current: float = field(default=0.0, init=False)
|
current: float = field(default=0.0, init=False)
|
||||||
@@ -29,7 +52,27 @@ class ProgressContext:
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def progress(ctx: RequestContext, total: float | None = None):
|
def progress(
|
||||||
|
ctx: RequestContext[
|
||||||
|
BaseSession[
|
||||||
|
SendRequestT,
|
||||||
|
SendNotificationT,
|
||||||
|
SendResultT,
|
||||||
|
ReceiveRequestT,
|
||||||
|
ReceiveNotificationT,
|
||||||
|
]
|
||||||
|
],
|
||||||
|
total: float | None = None,
|
||||||
|
) -> Generator[
|
||||||
|
ProgressContext[
|
||||||
|
SendRequestT,
|
||||||
|
SendNotificationT,
|
||||||
|
SendResultT,
|
||||||
|
ReceiveRequestT,
|
||||||
|
ReceiveNotificationT,
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
]:
|
||||||
if ctx.meta is None or ctx.meta.progressToken is None:
|
if ctx.meta is None or ctx.meta.progressToken is None:
|
||||||
raise ValueError("No progress token provided")
|
raise ValueError("No progress token provided")
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any, Callable, Generic, TypeVar
|
from types import TracebackType
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import anyio.lowlevel
|
import anyio.lowlevel
|
||||||
@@ -86,7 +88,12 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
self._cancel_scope.__enter__()
|
self._cancel_scope.__enter__()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||||
try:
|
try:
|
||||||
if self._completed:
|
if self._completed:
|
||||||
@@ -112,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
if not self.cancelled:
|
if not self.cancelled:
|
||||||
self._completed = True
|
self._completed = True
|
||||||
|
|
||||||
await self._session._send_response(
|
await self._session._send_response( # type: ignore[reportPrivateUsage]
|
||||||
request_id=self.request_id, response=response
|
request_id=self.request_id, response=response
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -126,7 +133,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
self._cancel_scope.cancel()
|
self._cancel_scope.cancel()
|
||||||
self._completed = True # Mark as completed so it's removed from in_flight
|
self._completed = True # Mark as completed so it's removed from in_flight
|
||||||
# Send an error response to indicate cancellation
|
# Send an error response to indicate cancellation
|
||||||
await self._session._send_response(
|
await self._session._send_response( # type: ignore[reportPrivateUsage]
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||||
)
|
)
|
||||||
@@ -137,7 +144,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def cancelled(self) -> bool:
|
def cancelled(self) -> bool:
|
||||||
return self._cancel_scope is not None and self._cancel_scope.cancel_called
|
return self._cancel_scope.cancel_called
|
||||||
|
|
||||||
|
|
||||||
class BaseSession(
|
class BaseSession(
|
||||||
@@ -202,7 +209,12 @@ class BaseSession(
|
|||||||
self._task_group.start_soon(self._receive_loop)
|
self._task_group.start_soon(self._receive_loop)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> bool | None:
|
||||||
await self._exit_stack.aclose()
|
await self._exit_stack.aclose()
|
||||||
# Using BaseSession as a context manager should not block on exit (this
|
# Using BaseSession as a context manager should not block on exit (this
|
||||||
# would be very surprising behavior), so make sure to cancel the tasks
|
# would be very surprising behavior), so make sure to cancel the tasks
|
||||||
@@ -324,7 +336,7 @@ class BaseSession(
|
|||||||
|
|
||||||
self._in_flight[responder.request_id] = responder
|
self._in_flight[responder.request_id] = responder
|
||||||
await self._received_request(responder)
|
await self._received_request(responder)
|
||||||
if not responder._completed:
|
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||||
await self._incoming_message_stream_writer.send(responder)
|
await self._incoming_message_stream_writer.send(responder)
|
||||||
|
|
||||||
elif isinstance(message.root, JSONRPCNotification):
|
elif isinstance(message.root, JSONRPCNotification):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
|
||||||
Generic,
|
Generic,
|
||||||
Literal,
|
Literal,
|
||||||
TypeAlias,
|
TypeAlias,
|
||||||
@@ -89,6 +89,7 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
|||||||
"""Base class for JSON-RPC notifications."""
|
"""Base class for JSON-RPC notifications."""
|
||||||
|
|
||||||
method: MethodT
|
method: MethodT
|
||||||
|
params: NotificationParamsT
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@@ -1010,7 +1011,9 @@ class CancelledNotificationParams(NotificationParams):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class CancelledNotification(Notification):
|
class CancelledNotification(
|
||||||
|
Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
This notification can be sent by either side to indicate that it is cancelling a
|
This notification can be sent by either side to indicate that it is cancelling a
|
||||||
previously-issued request.
|
previously-issued request.
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -8,7 +9,7 @@ from mcp.cli.claude import update_claude_config
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def temp_config_dir(tmp_path):
|
def temp_config_dir(tmp_path: Path):
|
||||||
"""Create a temporary Claude config directory."""
|
"""Create a temporary Claude config directory."""
|
||||||
config_dir = tmp_path / "Claude"
|
config_dir = tmp_path / "Claude"
|
||||||
config_dir.mkdir()
|
config_dir.mkdir()
|
||||||
@@ -16,23 +17,20 @@ def temp_config_dir(tmp_path):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config_path(temp_config_dir):
|
def mock_config_path(temp_config_dir: Path):
|
||||||
"""Mock get_claude_config_path to return our temporary directory."""
|
"""Mock get_claude_config_path to return our temporary directory."""
|
||||||
with patch("mcp.cli.claude.get_claude_config_path", return_value=temp_config_dir):
|
with patch("mcp.cli.claude.get_claude_config_path", return_value=temp_config_dir):
|
||||||
yield temp_config_dir
|
yield temp_config_dir
|
||||||
|
|
||||||
|
|
||||||
def test_command_execution(mock_config_path):
|
def test_command_execution(mock_config_path: Path):
|
||||||
"""Test that the generated command can actually be executed."""
|
"""Test that the generated command can actually be executed."""
|
||||||
# Setup
|
# Setup
|
||||||
server_name = "test_server"
|
server_name = "test_server"
|
||||||
file_spec = "test_server.py:app"
|
file_spec = "test_server.py:app"
|
||||||
|
|
||||||
# Update config
|
# Update config
|
||||||
success = update_claude_config(
|
success = update_claude_config(file_spec=file_spec, server_name=server_name)
|
||||||
file_spec=file_spec,
|
|
||||||
server_name=server_name,
|
|
||||||
)
|
|
||||||
assert success
|
assert success
|
||||||
|
|
||||||
# Read the generated config
|
# Read the generated config
|
||||||
|
|||||||
@@ -7,11 +7,7 @@ from mcp.shared.context import RequestContext
|
|||||||
from mcp.shared.memory import (
|
from mcp.shared.memory import (
|
||||||
create_connected_server_and_client_session as create_session,
|
create_connected_server_and_client_session as create_session,
|
||||||
)
|
)
|
||||||
from mcp.types import (
|
from mcp.types import ListRootsResult, Root, TextContent
|
||||||
ListRootsResult,
|
|
||||||
Root,
|
|
||||||
TextContent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -39,7 +35,7 @@ async def test_list_roots_callback():
|
|||||||
return callback_return
|
return callback_return
|
||||||
|
|
||||||
@server.tool("test_list_roots")
|
@server.tool("test_list_roots")
|
||||||
async def test_list_roots(context: Context, message: str):
|
async def test_list_roots(context: Context, message: str): # type: ignore[reportUnknownMemberType]
|
||||||
roots = await context.session.list_roots()
|
roots = await context.session.list_roots()
|
||||||
assert roots == callback_return
|
assert roots == callback_return
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Literal
|
from typing import Literal
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
@@ -14,7 +14,7 @@ from mcp.types import (
|
|||||||
|
|
||||||
class LoggingCollector:
|
class LoggingCollector:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log_messages: List[LoggingMessageNotificationParams] = []
|
self.log_messages: list[LoggingMessageNotificationParams] = []
|
||||||
|
|
||||||
async def __call__(self, params: LoggingMessageNotificationParams) -> None:
|
async def __call__(self, params: LoggingMessageNotificationParams) -> None:
|
||||||
self.log_messages.append(params)
|
self.log_messages.append(params)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Test to reproduce issue #88: Random error thrown on response."""
|
"""Test to reproduce issue #88: Random error thrown on response."""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Sequence
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import base64
|
import base64
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
@@ -114,7 +114,7 @@ def image_tool_fn(path: str) -> Image:
|
|||||||
return Image(path)
|
return Image(path)
|
||||||
|
|
||||||
|
|
||||||
def mixed_content_tool_fn() -> list[Union[TextContent, ImageContent]]:
|
def mixed_content_tool_fn() -> list[TextContent | ImageContent]:
|
||||||
return [
|
return [
|
||||||
TextContent(type="text", text="Hello"),
|
TextContent(type="text", text="Hello"),
|
||||||
ImageContent(type="image", data="abc", mimeType="image/png"),
|
ImageContent(type="image", data="abc", mimeType="image/png"),
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -296,7 +295,7 @@ class TestContextHandling:
|
|||||||
"""Test that context is optional when calling tools."""
|
"""Test that context is optional when calling tools."""
|
||||||
from mcp.server.fastmcp import Context
|
from mcp.server.fastmcp import Context
|
||||||
|
|
||||||
def tool_with_context(x: int, ctx: Optional[Context] = None) -> str:
|
def tool_with_context(x: int, ctx: Context | None = None) -> str:
|
||||||
return str(x)
|
return str(x)
|
||||||
|
|
||||||
manager = ToolManager()
|
manager = ToolManager()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Tests for lifespan functionality in both low-level and FastMCP servers."""
|
"""Tests for lifespan functionality in both low-level and FastMCP servers."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import AsyncIterator
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import httpx
|
import httpx
|
||||||
@@ -139,7 +139,7 @@ def server(server_port: int) -> Generator[None, None, None]:
|
|||||||
attempt += 1
|
attempt += 1
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Server failed to start after {} attempts".format(max_attempts)
|
f"Server failed to start after {max_attempts} attempts"
|
||||||
)
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
@@ -135,7 +135,7 @@ def server(server_port: int) -> Generator[None, None, None]:
|
|||||||
attempt += 1
|
attempt += 1
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Server failed to start after {} attempts".format(max_attempts)
|
f"Server failed to start after {max_attempts} attempts"
|
||||||
)
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|||||||
Reference in New Issue
Block a user