mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-20 15:24:25 +01:00
Merge branch 'modelcontextprotocol:main' into patch-1
This commit is contained in:
@@ -1,13 +1,51 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import BaseSession, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
async def __call__(
|
||||
self, context: RequestContext["ClientSession", Any]
|
||||
) -> types.ListRootsResult | types.ErrorData: ...
|
||||
|
||||
|
||||
async def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
@@ -22,6 +60,8 @@ class ClientSession(
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
@@ -30,8 +70,24 @@ class ClientSession(
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
|
||||
async def initialize(self) -> types.InitializeResult:
|
||||
sampling = (
|
||||
types.SamplingCapability() if self._sampling_callback is not None else None
|
||||
)
|
||||
roots = (
|
||||
types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True,
|
||||
)
|
||||
if self._list_roots_callback is not None
|
||||
else None
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
@@ -39,14 +95,9 @@ class ClientSession(
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=None,
|
||||
sampling=sampling,
|
||||
experimental=None,
|
||||
roots=types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True
|
||||
),
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
|
||||
),
|
||||
@@ -120,6 +171,17 @@ class ClientSession(
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
async def list_resource_templates(self) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourceTemplatesRequest(
|
||||
method="resources/templates/list",
|
||||
)
|
||||
),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return await self.send_request(
|
||||
@@ -232,3 +294,32 @@ class ClientSession(
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
) -> None:
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
response = await self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
response = await self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.PingRequest():
|
||||
with responder:
|
||||
return await responder.respond(
|
||||
types.ClientResult(root=types.EmptyResult())
|
||||
)
|
||||
|
||||
@@ -3,8 +3,13 @@
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from contextlib import (
|
||||
AbstractAsyncContextManager,
|
||||
asynccontextmanager,
|
||||
)
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Literal, Sequence
|
||||
from typing import Any, Callable, Generic, Literal, Sequence
|
||||
|
||||
import anyio
|
||||
import pydantic_core
|
||||
@@ -19,12 +24,22 @@ from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceMan
|
||||
from mcp.server.fastmcp.tools import ToolManager
|
||||
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
|
||||
from mcp.server.fastmcp.utilities.types import Image
|
||||
from mcp.server.lowlevel import Server as MCPServer
|
||||
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||
from mcp.server.lowlevel.server import (
|
||||
LifespanResultT,
|
||||
)
|
||||
from mcp.server.lowlevel.server import (
|
||||
Server as MCPServer,
|
||||
)
|
||||
from mcp.server.lowlevel.server import (
|
||||
lifespan as default_lifespan,
|
||||
)
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.types import (
|
||||
AnyFunction,
|
||||
EmbeddedResource,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
@@ -49,7 +64,7 @@ from mcp.types import (
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
"""FastMCP server settings.
|
||||
|
||||
All settings can be configured via environment variables with the prefix FASTMCP_.
|
||||
@@ -84,13 +99,36 @@ class Settings(BaseSettings):
|
||||
description="List of dependencies to install in the server environment",
|
||||
)
|
||||
|
||||
lifespan: (
|
||||
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||
) = Field(None, description="Lifespan context manager")
|
||||
|
||||
|
||||
def lifespan_wrapper(
|
||||
app: "FastMCP",
|
||||
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
|
||||
) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]:
|
||||
@asynccontextmanager
|
||||
async def wrap(s: MCPServer) -> AsyncIterator[object]:
|
||||
async with lifespan(app) as context:
|
||||
yield context
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
class FastMCP:
|
||||
def __init__(
|
||||
self, name: str | None = None, instructions: str | None = None, **settings: Any
|
||||
):
|
||||
self.settings = Settings(**settings)
|
||||
self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions)
|
||||
|
||||
self._mcp_server = MCPServer(
|
||||
name=name or "FastMCP",
|
||||
instructions=instructions,
|
||||
lifespan=lifespan_wrapper(self, self.settings.lifespan)
|
||||
if self.settings.lifespan
|
||||
else default_lifespan,
|
||||
)
|
||||
self._tool_manager = ToolManager(
|
||||
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
|
||||
)
|
||||
@@ -165,7 +203,7 @@ class FastMCP:
|
||||
return Context(request_context=request_context, fastmcp=self)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict
|
||||
self, name: str, arguments: dict[str, Any]
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
"""Call a tool by name with arguments."""
|
||||
context = self.get_context()
|
||||
@@ -198,7 +236,7 @@ class FastMCP:
|
||||
for template in templates
|
||||
]
|
||||
|
||||
async def read_resource(self, uri: AnyUrl | str) -> ReadResourceContents:
|
||||
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
|
||||
"""Read a resource by URI."""
|
||||
|
||||
resource = await self._resource_manager.get_resource(uri)
|
||||
@@ -207,14 +245,14 @@ class FastMCP:
|
||||
|
||||
try:
|
||||
content = await resource.read()
|
||||
return ReadResourceContents(content=content, mime_type=resource.mime_type)
|
||||
return [ReadResourceContents(content=content, mime_type=resource.mime_type)]
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading resource {uri}: {e}")
|
||||
raise ResourceError(str(e))
|
||||
|
||||
def add_tool(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: AnyFunction,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> None:
|
||||
@@ -230,7 +268,9 @@ class FastMCP:
|
||||
"""
|
||||
self._tool_manager.add_tool(fn, name=name, description=description)
|
||||
|
||||
def tool(self, name: str | None = None, description: str | None = None) -> Callable:
|
||||
def tool(
|
||||
self, name: str | None = None, description: str | None = None
|
||||
) -> Callable[[AnyFunction], AnyFunction]:
|
||||
"""Decorator to register a tool.
|
||||
|
||||
Tools can optionally request a Context object by adding a parameter with the
|
||||
@@ -263,7 +303,7 @@ class FastMCP:
|
||||
"Did you forget to call it? Use @tool() instead of @tool"
|
||||
)
|
||||
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
def decorator(fn: AnyFunction) -> AnyFunction:
|
||||
self.add_tool(fn, name=name, description=description)
|
||||
return fn
|
||||
|
||||
@@ -284,7 +324,7 @@ class FastMCP:
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> Callable:
|
||||
) -> Callable[[AnyFunction], AnyFunction]:
|
||||
"""Decorator to register a function as a resource.
|
||||
|
||||
The function will be called when the resource is read to generate its content.
|
||||
@@ -328,7 +368,7 @@ class FastMCP:
|
||||
"Did you forget to call it? Use @resource('uri') instead of @resource"
|
||||
)
|
||||
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
def decorator(fn: AnyFunction) -> AnyFunction:
|
||||
# Check if this should be a template
|
||||
has_uri_params = "{" in uri and "}" in uri
|
||||
has_func_params = bool(inspect.signature(fn).parameters)
|
||||
@@ -376,7 +416,7 @@ class FastMCP:
|
||||
|
||||
def prompt(
|
||||
self, name: str | None = None, description: str | None = None
|
||||
) -> Callable:
|
||||
) -> Callable[[AnyFunction], AnyFunction]:
|
||||
"""Decorator to register a prompt.
|
||||
|
||||
Args:
|
||||
@@ -417,7 +457,7 @@ class FastMCP:
|
||||
"Did you forget to call it? Use @prompt() instead of @prompt"
|
||||
)
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
def decorator(func: AnyFunction) -> AnyFunction:
|
||||
prompt = Prompt.from_function(func, name=name, description=description)
|
||||
self.add_prompt(prompt)
|
||||
return func
|
||||
@@ -558,7 +598,7 @@ class Context(BaseModel):
|
||||
The context is optional - tools that don't need it can omit the parameter.
|
||||
"""
|
||||
|
||||
_request_context: RequestContext | None
|
||||
_request_context: RequestContext[ServerSession, Any] | None
|
||||
_fastmcp: FastMCP | None
|
||||
|
||||
def __init__(
|
||||
@@ -602,14 +642,14 @@ class Context(BaseModel):
|
||||
else None
|
||||
)
|
||||
|
||||
if not progress_token:
|
||||
if progress_token is None:
|
||||
return
|
||||
|
||||
await self.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=progress, total=total
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: str | AnyUrl) -> ReadResourceContents:
|
||||
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
|
||||
"""Read a resource by URI.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -67,9 +67,11 @@ messages from the client.
|
||||
import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Sequence
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
@@ -84,7 +86,10 @@ from mcp.shared.session import RequestResponder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
|
||||
LifespanResultT = TypeVar("LifespanResultT")
|
||||
|
||||
# This will be properly typed in each Server instance's context
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
|
||||
contextvars.ContextVar("request_ctx")
|
||||
)
|
||||
|
||||
@@ -101,13 +106,33 @@ class NotificationOptions:
|
||||
self.tools_changed = tools_changed
|
||||
|
||||
|
||||
class Server:
|
||||
@asynccontextmanager
|
||||
async def lifespan(server: "Server") -> AsyncIterator[object]:
|
||||
"""Default lifespan context manager that does nothing.
|
||||
|
||||
Args:
|
||||
server: The server instance this lifespan is managing
|
||||
|
||||
Returns:
|
||||
An empty context object
|
||||
"""
|
||||
yield {}
|
||||
|
||||
|
||||
class Server(Generic[LifespanResultT]):
|
||||
def __init__(
|
||||
self, name: str, version: str | None = None, instructions: str | None = None
|
||||
self,
|
||||
name: str,
|
||||
version: str | None = None,
|
||||
instructions: str | None = None,
|
||||
lifespan: Callable[
|
||||
["Server"], AbstractAsyncContextManager[LifespanResultT]
|
||||
] = lifespan,
|
||||
):
|
||||
self.name = name
|
||||
self.version = version
|
||||
self.instructions = instructions
|
||||
self.lifespan = lifespan
|
||||
self.request_handlers: dict[
|
||||
type, Callable[..., Awaitable[types.ServerResult]]
|
||||
] = {
|
||||
@@ -188,7 +213,7 @@ class Server:
|
||||
)
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext[ServerSession]:
|
||||
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
@@ -254,7 +279,9 @@ class Server:
|
||||
|
||||
def read_resource(self):
|
||||
def decorator(
|
||||
func: Callable[[AnyUrl], Awaitable[str | bytes | ReadResourceContents]],
|
||||
func: Callable[
|
||||
[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for ReadResourceRequest")
|
||||
|
||||
@@ -282,13 +309,22 @@ class Server:
|
||||
case str() | bytes() as data:
|
||||
warnings.warn(
|
||||
"Returning str or bytes from read_resource is deprecated. "
|
||||
"Use ReadResourceContents instead.",
|
||||
"Use Iterable[ReadResourceContents] instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
content = create_content(data, None)
|
||||
case ReadResourceContents() as contents:
|
||||
content = create_content(contents.content, contents.mime_type)
|
||||
case Iterable() as contents:
|
||||
contents_list = [
|
||||
create_content(content_item.content, content_item.mime_type)
|
||||
for content_item in contents
|
||||
if isinstance(content_item, ReadResourceContents)
|
||||
]
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
contents=contents_list,
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unexpected return type from read_resource: {type(result)}"
|
||||
@@ -362,7 +398,7 @@ class Server:
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[
|
||||
Sequence[
|
||||
Iterable[
|
||||
types.TextContent | types.ImageContent | types.EmbeddedResource
|
||||
]
|
||||
],
|
||||
@@ -445,31 +481,54 @@ class Server:
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(
|
||||
read_stream, write_stream, initialization_options
|
||||
) as session:
|
||||
async with AsyncExitStack() as stack:
|
||||
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||
session = await stack.enter_async_context(
|
||||
ServerSession(read_stream, write_stream, initialization_options)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
match message:
|
||||
case RequestResponder(request=types.ClientRequest(root=req)):
|
||||
await self._handle_request(
|
||||
message, req, session, raise_exceptions
|
||||
)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
tg.start_soon(
|
||||
self._handle_message,
|
||||
message,
|
||||
session,
|
||||
lifespan_context,
|
||||
raise_exceptions,
|
||||
)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
async def _handle_message(
|
||||
self,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
| types.ClientNotification
|
||||
| Exception,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
match message:
|
||||
case (
|
||||
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
||||
):
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
message, req, session, lifespan_context, raise_exceptions
|
||||
)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
|
||||
for warning in w:
|
||||
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
message: RequestResponder,
|
||||
req: Any,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool,
|
||||
):
|
||||
logger.info(f"Processing request of type {type(req).__name__}")
|
||||
@@ -486,6 +545,7 @@ class Server:
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
lifespan_context,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
|
||||
@@ -126,19 +126,20 @@ class ServerSession(
|
||||
case types.InitializeRequest(params=params):
|
||||
self._initialization_state = InitializationState.Initializing
|
||||
self._client_params = params
|
||||
await responder.respond(
|
||||
types.ServerResult(
|
||||
types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self._init_options.capabilities,
|
||||
serverInfo=types.Implementation(
|
||||
name=self._init_options.server_name,
|
||||
version=self._init_options.server_version,
|
||||
),
|
||||
instructions=self._init_options.instructions,
|
||||
with responder:
|
||||
await responder.respond(
|
||||
types.ServerResult(
|
||||
types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self._init_options.capabilities,
|
||||
serverInfo=types.Implementation(
|
||||
name=self._init_options.server_name,
|
||||
version=self._init_options.server_version,
|
||||
),
|
||||
instructions=self._init_options.instructions,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -20,6 +20,7 @@ Example usage:
|
||||
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from io import TextIOWrapper
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
@@ -38,11 +39,13 @@ async def stdio_server(
|
||||
from the current process' stdin and writing to stdout.
|
||||
"""
|
||||
# Purposely not using context managers for these, as we don't want to close
|
||||
# standard process handles.
|
||||
# standard process handles. Encoding of stdin/stdout as text streams on
|
||||
# python is platform-dependent (Windows is particularly problematic), so we
|
||||
# re-wrap the underlying binary stream to ensure UTF-8.
|
||||
if not stdin:
|
||||
stdin = anyio.wrap_file(sys.stdin)
|
||||
stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8"))
|
||||
if not stdout:
|
||||
stdout = anyio.wrap_file(sys.stdout)
|
||||
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
@@ -5,10 +5,12 @@ from mcp.shared.session import BaseSession
|
||||
from mcp.types import RequestId, RequestParams
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession)
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT]):
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import AsyncGenerator
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
|
||||
from mcp.server import Server
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
@@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> (
|
||||
async def create_connected_server_and_client_session(
|
||||
server: Server,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
raise_exceptions: bool = False,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||
@@ -80,6 +82,8 @@ async def create_connected_server_and_client_session(
|
||||
read_stream=client_read,
|
||||
write_stream=client_write,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
sampling_callback=sampling_callback,
|
||||
list_roots_callback=list_roots_callback,
|
||||
) as client_session:
|
||||
await client_session.initialize()
|
||||
yield client_session
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from datetime import timedelta
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
@@ -10,6 +11,7 @@ from pydantic import BaseModel
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
@@ -38,27 +40,98 @@ RequestId = str | int
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
This class MUST be used as a context manager to ensure proper cleanup and
|
||||
cancellation handling:
|
||||
|
||||
Example:
|
||||
with request_responder as resp:
|
||||
await resp.respond(result)
|
||||
|
||||
The context manager ensures:
|
||||
1. Proper cancellation scope setup and cleanup
|
||||
2. Request completion tracking
|
||||
3. Cleanup of in-flight requests
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: "BaseSession",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self._responded = False
|
||||
self._completed = False
|
||||
self._cancel_scope = anyio.CancelScope()
|
||||
self._on_complete = on_complete
|
||||
self._entered = False # Track if we're in a context manager
|
||||
|
||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||
"""Enter the context manager, enabling request cancellation tracking."""
|
||||
self._entered = True
|
||||
self._cancel_scope = anyio.CancelScope()
|
||||
self._cancel_scope.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||
try:
|
||||
if self._completed:
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._entered = False
|
||||
if not self._cancel_scope:
|
||||
raise RuntimeError("No active cancel scope")
|
||||
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
assert not self._responded, "Request already responded to"
|
||||
self._responded = True
|
||||
"""Send a response for this request.
|
||||
|
||||
Must be called within a context manager block.
|
||||
Raises:
|
||||
RuntimeError: If not used within a context manager
|
||||
AssertionError: If request was already responded to
|
||||
"""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
assert not self._completed, "Request already responded to"
|
||||
|
||||
if not self.cancelled:
|
||||
self._completed = True
|
||||
|
||||
await self._session._send_response(
|
||||
request_id=self.request_id, response=response
|
||||
)
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel this request and mark it as completed."""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
if not self._cancel_scope:
|
||||
raise RuntimeError("No active cancel scope")
|
||||
|
||||
self._cancel_scope.cancel()
|
||||
self._completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
await self._session._send_response(
|
||||
request_id=self.request_id, response=response
|
||||
request_id=self.request_id,
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
@property
|
||||
def in_flight(self) -> bool:
|
||||
return not self._completed and not self.cancelled
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
return self._cancel_scope is not None and self._cancel_scope.cancel_called
|
||||
|
||||
|
||||
class BaseSession(
|
||||
AbstractAsyncContextManager,
|
||||
@@ -82,6 +155,7 @@ class BaseSession(
|
||||
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
|
||||
]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -99,6 +173,7 @@ class BaseSession(
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[
|
||||
@@ -219,6 +294,7 @@ class BaseSession(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.root.id,
|
||||
request_meta=validated_request.root.params.meta
|
||||
@@ -226,20 +302,37 @@ class BaseSession(
|
||||
else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
await self._received_request(responder)
|
||||
if not responder._responded:
|
||||
if not responder._completed:
|
||||
await self._incoming_message_stream_writer.send(responder)
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(notification)
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(
|
||||
notification
|
||||
)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
f"Failed to validate notification: {e}. "
|
||||
f"Message was: {message.root}"
|
||||
)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.root.id, None)
|
||||
if stream:
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
from typing import Annotated, Any, Generic, Literal, TypeVar
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
||||
from pydantic.networks import AnyUrl
|
||||
from pydantic.networks import AnyUrl, UrlConstraints
|
||||
|
||||
"""
|
||||
Model Context Protocol bindings for Python
|
||||
@@ -27,6 +35,7 @@ ProgressToken = str | int
|
||||
Cursor = str
|
||||
Role = Literal["user", "assistant"]
|
||||
RequestId = str | int
|
||||
AnyFunction: TypeAlias = Callable[..., Any]
|
||||
|
||||
|
||||
class RequestParams(BaseModel):
|
||||
@@ -352,7 +361,7 @@ class Annotations(BaseModel):
|
||||
class Resource(BaseModel):
|
||||
"""A known resource that the server is capable of reading."""
|
||||
|
||||
uri: AnyUrl
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||
"""The URI of this resource."""
|
||||
name: str
|
||||
"""A human-readable name for this resource."""
|
||||
@@ -414,7 +423,7 @@ class ListResourceTemplatesResult(PaginatedResult):
|
||||
class ReadResourceRequestParams(RequestParams):
|
||||
"""Parameters for reading a resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||
"""
|
||||
The URI of the resource to read. The URI can use any protocol; it is up to the
|
||||
server how to interpret it.
|
||||
@@ -432,7 +441,7 @@ class ReadResourceRequest(Request):
|
||||
class ResourceContents(BaseModel):
|
||||
"""The contents of a specific resource or sub-resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||
"""The URI of this resource."""
|
||||
mimeType: str | None = None
|
||||
"""The MIME type of this resource, if known."""
|
||||
@@ -475,7 +484,7 @@ class ResourceListChangedNotification(Notification):
|
||||
class SubscribeRequestParams(RequestParams):
|
||||
"""Parameters for subscribing to a resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||
"""
|
||||
The URI of the resource to subscribe to. The URI can use any protocol; it is up to
|
||||
the server how to interpret it.
|
||||
@@ -496,7 +505,7 @@ class SubscribeRequest(Request):
|
||||
class UnsubscribeRequestParams(RequestParams):
|
||||
"""Parameters for unsubscribing from a resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||
"""The URI of the resource to unsubscribe from."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -514,7 +523,7 @@ class UnsubscribeRequest(Request):
|
||||
class ResourceUpdatedNotificationParams(NotificationParams):
|
||||
"""Parameters for resource update notifications."""
|
||||
|
||||
uri: AnyUrl
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||
"""
|
||||
The URI of the resource that has been updated. This might be a sub-resource of the
|
||||
one that the client actually subscribed to.
|
||||
|
||||
Reference in New Issue
Block a user