Merge branch 'modelcontextprotocol:main' into patch-1

This commit is contained in:
Henry Mao
2025-03-06 13:24:01 +08:00
committed by GitHub
37 changed files with 2066 additions and 155 deletions

View File

@@ -1,13 +1,51 @@
from datetime import timedelta
from typing import Any, Protocol
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from pydantic import AnyUrl, TypeAdapter
import mcp.types as types
from mcp.shared.session import BaseSession
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...
class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
) -> types.ListRootsResult | types.ErrorData: ...
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
@@ -22,6 +60,8 @@ class ClientSession(
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
) -> None:
super().__init__(
read_stream,
@@ -30,8 +70,24 @@ class ClientSession(
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
async def initialize(self) -> types.InitializeResult:
sampling = (
types.SamplingCapability() if self._sampling_callback is not None else None
)
roots = (
types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)
if self._list_roots_callback is not None
else None
)
result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
@@ -39,14 +95,9 @@ class ClientSession(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=None,
sampling=sampling,
experimental=None,
roots=types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True
),
roots=roots,
),
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
),
@@ -120,6 +171,17 @@ class ClientSession(
types.ListResourcesResult,
)
async def list_resource_templates(self) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request."""
return await self.send_request(
types.ClientRequest(
types.ListResourceTemplatesRequest(
method="resources/templates/list",
)
),
types.ListResourceTemplatesResult,
)
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return await self.send_request(
@@ -232,3 +294,32 @@ class ClientSession(
)
)
)
async def _received_request(
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
response = await self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.PingRequest():
with responder:
return await responder.respond(
types.ClientResult(root=types.EmptyResult())
)

View File

@@ -3,8 +3,13 @@
import inspect
import json
import re
from collections.abc import AsyncIterator, Iterable
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from itertools import chain
from typing import Any, Callable, Literal, Sequence
from typing import Any, Callable, Generic, Literal, Sequence
import anyio
import pydantic_core
@@ -19,12 +24,22 @@ from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceMan
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.fastmcp.utilities.types import Image
from mcp.server.lowlevel import Server as MCPServer
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel.server import (
LifespanResultT,
)
from mcp.server.lowlevel.server import (
Server as MCPServer,
)
from mcp.server.lowlevel.server import (
lifespan as default_lifespan,
)
from mcp.server.session import ServerSession
from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
from mcp.shared.context import RequestContext
from mcp.types import (
AnyFunction,
EmbeddedResource,
GetPromptResult,
ImageContent,
@@ -49,7 +64,7 @@ from mcp.types import (
logger = get_logger(__name__)
class Settings(BaseSettings):
class Settings(BaseSettings, Generic[LifespanResultT]):
"""FastMCP server settings.
All settings can be configured via environment variables with the prefix FASTMCP_.
@@ -84,13 +99,36 @@ class Settings(BaseSettings):
description="List of dependencies to install in the server environment",
)
lifespan: (
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager")
def lifespan_wrapper(
app: "FastMCP",
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]:
@asynccontextmanager
async def wrap(s: MCPServer) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context
return wrap
class FastMCP:
def __init__(
self, name: str | None = None, instructions: str | None = None, **settings: Any
):
self.settings = Settings(**settings)
self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions)
self._mcp_server = MCPServer(
name=name or "FastMCP",
instructions=instructions,
lifespan=lifespan_wrapper(self, self.settings.lifespan)
if self.settings.lifespan
else default_lifespan,
)
self._tool_manager = ToolManager(
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
)
@@ -165,7 +203,7 @@ class FastMCP:
return Context(request_context=request_context, fastmcp=self)
async def call_tool(
self, name: str, arguments: dict
self, name: str, arguments: dict[str, Any]
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
"""Call a tool by name with arguments."""
context = self.get_context()
@@ -198,7 +236,7 @@ class FastMCP:
for template in templates
]
async def read_resource(self, uri: AnyUrl | str) -> ReadResourceContents:
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
"""Read a resource by URI."""
resource = await self._resource_manager.get_resource(uri)
@@ -207,14 +245,14 @@ class FastMCP:
try:
content = await resource.read()
return ReadResourceContents(content=content, mime_type=resource.mime_type)
return [ReadResourceContents(content=content, mime_type=resource.mime_type)]
except Exception as e:
logger.error(f"Error reading resource {uri}: {e}")
raise ResourceError(str(e))
def add_tool(
self,
fn: Callable,
fn: AnyFunction,
name: str | None = None,
description: str | None = None,
) -> None:
@@ -230,7 +268,9 @@ class FastMCP:
"""
self._tool_manager.add_tool(fn, name=name, description=description)
def tool(self, name: str | None = None, description: str | None = None) -> Callable:
def tool(
self, name: str | None = None, description: str | None = None
) -> Callable[[AnyFunction], AnyFunction]:
"""Decorator to register a tool.
Tools can optionally request a Context object by adding a parameter with the
@@ -263,7 +303,7 @@ class FastMCP:
"Did you forget to call it? Use @tool() instead of @tool"
)
def decorator(fn: Callable) -> Callable:
def decorator(fn: AnyFunction) -> AnyFunction:
self.add_tool(fn, name=name, description=description)
return fn
@@ -284,7 +324,7 @@ class FastMCP:
name: str | None = None,
description: str | None = None,
mime_type: str | None = None,
) -> Callable:
) -> Callable[[AnyFunction], AnyFunction]:
"""Decorator to register a function as a resource.
The function will be called when the resource is read to generate its content.
@@ -328,7 +368,7 @@ class FastMCP:
"Did you forget to call it? Use @resource('uri') instead of @resource"
)
def decorator(fn: Callable) -> Callable:
def decorator(fn: AnyFunction) -> AnyFunction:
# Check if this should be a template
has_uri_params = "{" in uri and "}" in uri
has_func_params = bool(inspect.signature(fn).parameters)
@@ -376,7 +416,7 @@ class FastMCP:
def prompt(
self, name: str | None = None, description: str | None = None
) -> Callable:
) -> Callable[[AnyFunction], AnyFunction]:
"""Decorator to register a prompt.
Args:
@@ -417,7 +457,7 @@ class FastMCP:
"Did you forget to call it? Use @prompt() instead of @prompt"
)
def decorator(func: Callable) -> Callable:
def decorator(func: AnyFunction) -> AnyFunction:
prompt = Prompt.from_function(func, name=name, description=description)
self.add_prompt(prompt)
return func
@@ -558,7 +598,7 @@ class Context(BaseModel):
The context is optional - tools that don't need it can omit the parameter.
"""
_request_context: RequestContext | None
_request_context: RequestContext[ServerSession, Any] | None
_fastmcp: FastMCP | None
def __init__(
@@ -602,14 +642,14 @@ class Context(BaseModel):
else None
)
if not progress_token:
if progress_token is None:
return
await self.request_context.session.send_progress_notification(
progress_token=progress_token, progress=progress, total=total
)
async def read_resource(self, uri: str | AnyUrl) -> ReadResourceContents:
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
"""Read a resource by URI.
Args:

View File

@@ -67,9 +67,11 @@ messages from the client.
import contextvars
import logging
import warnings
from collections.abc import Awaitable, Callable
from typing import Any, Sequence
from collections.abc import Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Generic, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
@@ -84,7 +86,10 @@ from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__)
request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
LifespanResultT = TypeVar("LifespanResultT")
# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
contextvars.ContextVar("request_ctx")
)
@@ -101,13 +106,33 @@ class NotificationOptions:
self.tools_changed = tools_changed
class Server:
@asynccontextmanager
async def lifespan(server: "Server") -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.
Args:
server: The server instance this lifespan is managing
Returns:
An empty context object
"""
yield {}
class Server(Generic[LifespanResultT]):
def __init__(
self, name: str, version: str | None = None, instructions: str | None = None
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
["Server"], AbstractAsyncContextManager[LifespanResultT]
] = lifespan,
):
self.name = name
self.version = version
self.instructions = instructions
self.lifespan = lifespan
self.request_handlers: dict[
type, Callable[..., Awaitable[types.ServerResult]]
] = {
@@ -188,7 +213,7 @@ class Server:
)
@property
def request_context(self) -> RequestContext[ServerSession]:
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()
@@ -254,7 +279,9 @@ class Server:
def read_resource(self):
def decorator(
func: Callable[[AnyUrl], Awaitable[str | bytes | ReadResourceContents]],
func: Callable[
[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]
],
):
logger.debug("Registering handler for ReadResourceRequest")
@@ -282,13 +309,22 @@ class Server:
case str() | bytes() as data:
warnings.warn(
"Returning str or bytes from read_resource is deprecated. "
"Use ReadResourceContents instead.",
"Use Iterable[ReadResourceContents] instead.",
DeprecationWarning,
stacklevel=2,
)
content = create_content(data, None)
case ReadResourceContents() as contents:
content = create_content(contents.content, contents.mime_type)
case Iterable() as contents:
contents_list = [
create_content(content_item.content, content_item.mime_type)
for content_item in contents
if isinstance(content_item, ReadResourceContents)
]
return types.ServerResult(
types.ReadResourceResult(
contents=contents_list,
)
)
case _:
raise ValueError(
f"Unexpected return type from read_resource: {type(result)}"
@@ -362,7 +398,7 @@ class Server:
func: Callable[
...,
Awaitable[
Sequence[
Iterable[
types.TextContent | types.ImageContent | types.EmbeddedResource
]
],
@@ -445,31 +481,54 @@ class Server:
# in-process servers.
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
session = await stack.enter_async_context(
ServerSession(read_stream, write_stream, initialization_options)
)
async with anyio.create_task_group() as tg:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
match message:
case RequestResponder(request=types.ClientRequest(root=req)):
await self._handle_request(
message, req, session, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
tg.start_soon(
self._handle_message,
message,
session,
lifespan_context,
raise_exceptions,
)
for warning in w:
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
match message:
case (
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
with responder:
await self._handle_request(
message, req, session, lifespan_context, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
async def _handle_request(
self,
message: RequestResponder,
req: Any,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool,
):
logger.info(f"Processing request of type {type(req).__name__}")
@@ -486,6 +545,7 @@ class Server:
message.request_id,
message.request_meta,
session,
lifespan_context,
)
)
response = await handler(req)

View File

@@ -126,19 +126,20 @@ class ServerSession(
case types.InitializeRequest(params=params):
self._initialization_state = InitializationState.Initializing
self._client_params = params
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
with responder:
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
)
)
)
)
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError(

View File

@@ -20,6 +20,7 @@ Example usage:
import sys
from contextlib import asynccontextmanager
from io import TextIOWrapper
import anyio
import anyio.lowlevel
@@ -38,11 +39,13 @@ async def stdio_server(
from the current process' stdin and writing to stdout.
"""
# Purposely not using context managers for these, as we don't want to close
# standard process handles.
# standard process handles. Encoding of stdin/stdout as text streams on
# python is platform-dependent (Windows is particularly problematic), so we
# re-wrap the underlying binary stream to ensure UTF-8.
if not stdin:
stdin = anyio.wrap_file(sys.stdin)
stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8"))
if not stdout:
stdout = anyio.wrap_file(sys.stdout)
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]

View File

@@ -5,10 +5,12 @@ from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParams
SessionT = TypeVar("SessionT", bound=BaseSession)
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext(Generic[SessionT]):
class RequestContext(Generic[SessionT, LifespanContextT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT

View File

@@ -9,7 +9,7 @@ from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.session import ClientSession
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
from mcp.server import Server
from mcp.types import JSONRPCMessage
@@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> (
async def create_connected_server_and_client_session(
server: Server,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -80,6 +82,8 @@ async def create_connected_server_and_client_session(
read_stream=client_read,
write_stream=client_write,
read_timeout_seconds=read_timeout_seconds,
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
) as client_session:
await client_session.initialize()
yield client_session

View File

@@ -1,6 +1,7 @@
import logging
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar
import anyio
import anyio.lowlevel
@@ -10,6 +11,7 @@ from pydantic import BaseModel
from mcp.shared.exceptions import McpError
from mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
@@ -38,27 +40,98 @@ RequestId = str | int
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
cancellation handling:
Example:
with request_responder as resp:
await resp.respond(result)
The context manager ensures:
1. Proper cancellation scope setup and cleanup
2. Request completion tracking
3. Cleanup of in-flight requests
"""
def __init__(
self,
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self._responded = False
self._completed = False
self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_scope = anyio.CancelScope()
self._cancel_scope.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
self._on_complete(self)
finally:
self._entered = False
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
async def respond(self, response: SendResultT | ErrorData) -> None:
assert not self._responded, "Request already responded to"
self._responded = True
"""Send a response for this request.
Must be called within a context manager block.
Raises:
RuntimeError: If not used within a context manager
AssertionError: If request was already responded to
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
if not self.cancelled:
self._completed = True
await self._session._send_response(
request_id=self.request_id, response=response
)
async def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
await self._session._send_response(
request_id=self.request_id, response=response
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_scope is not None and self._cancel_scope.cancel_called
class BaseSession(
AbstractAsyncContextManager,
@@ -82,6 +155,7 @@ class BaseSession(
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
def __init__(
self,
@@ -99,6 +173,7 @@ class BaseSession(
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
@@ -219,6 +294,7 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True
)
)
responder = RequestResponder(
request_id=message.root.id,
request_meta=validated_request.root.params.meta
@@ -226,20 +302,37 @@ class BaseSession(
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._responded:
if not responder._completed:
await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
elif isinstance(message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.root.id, None)
if stream:

View File

@@ -1,7 +1,15 @@
from typing import Annotated, Any, Generic, Literal, TypeVar
from typing import (
Annotated,
Any,
Callable,
Generic,
Literal,
TypeAlias,
TypeVar,
)
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl
from pydantic.networks import AnyUrl, UrlConstraints
"""
Model Context Protocol bindings for Python
@@ -27,6 +35,7 @@ ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = str | int
AnyFunction: TypeAlias = Callable[..., Any]
class RequestParams(BaseModel):
@@ -352,7 +361,7 @@ class Annotations(BaseModel):
class Resource(BaseModel):
"""A known resource that the server is capable of reading."""
uri: AnyUrl
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""The URI of this resource."""
name: str
"""A human-readable name for this resource."""
@@ -414,7 +423,7 @@ class ListResourceTemplatesResult(PaginatedResult):
class ReadResourceRequestParams(RequestParams):
"""Parameters for reading a resource."""
uri: AnyUrl
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""
The URI of the resource to read. The URI can use any protocol; it is up to the
server how to interpret it.
@@ -432,7 +441,7 @@ class ReadResourceRequest(Request):
class ResourceContents(BaseModel):
"""The contents of a specific resource or sub-resource."""
uri: AnyUrl
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""The URI of this resource."""
mimeType: str | None = None
"""The MIME type of this resource, if known."""
@@ -475,7 +484,7 @@ class ResourceListChangedNotification(Notification):
class SubscribeRequestParams(RequestParams):
"""Parameters for subscribing to a resource."""
uri: AnyUrl
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""
The URI of the resource to subscribe to. The URI can use any protocol; it is up to
the server how to interpret it.
@@ -496,7 +505,7 @@ class SubscribeRequest(Request):
class UnsubscribeRequestParams(RequestParams):
"""Parameters for unsubscribing from a resource."""
uri: AnyUrl
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""The URI of the resource to unsubscribe from."""
model_config = ConfigDict(extra="allow")
@@ -514,7 +523,7 @@ class UnsubscribeRequest(Request):
class ResourceUpdatedNotificationParams(NotificationParams):
"""Parameters for resource update notifications."""
uri: AnyUrl
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""
The URI of the resource that has been updated. This might be a sub-resource of the
one that the client actually subscribed to.