Merge branch 'modelcontextprotocol:main' into patch-1

This commit is contained in:
Henry Mao
2025-03-06 13:24:01 +08:00
committed by GitHub
37 changed files with 2066 additions and 155 deletions

View File

@@ -1,13 +1,51 @@
from datetime import timedelta
from typing import Any, Protocol
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from pydantic import AnyUrl, TypeAdapter
import mcp.types as types
from mcp.shared.session import BaseSession
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...
class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
) -> types.ListRootsResult | types.ErrorData: ...
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
@@ -22,6 +60,8 @@ class ClientSession(
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
) -> None:
super().__init__(
read_stream,
@@ -30,8 +70,24 @@ class ClientSession(
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
async def initialize(self) -> types.InitializeResult:
sampling = (
types.SamplingCapability() if self._sampling_callback is not None else None
)
roots = (
types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)
if self._list_roots_callback is not None
else None
)
result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
@@ -39,14 +95,9 @@ class ClientSession(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=None,
sampling=sampling,
experimental=None,
roots=types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True
),
roots=roots,
),
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
),
@@ -120,6 +171,17 @@ class ClientSession(
types.ListResourcesResult,
)
async def list_resource_templates(self) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request."""
return await self.send_request(
types.ClientRequest(
types.ListResourceTemplatesRequest(
method="resources/templates/list",
)
),
types.ListResourceTemplatesResult,
)
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return await self.send_request(
@@ -232,3 +294,32 @@ class ClientSession(
)
)
)
async def _received_request(
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
response = await self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.PingRequest():
with responder:
return await responder.respond(
types.ClientResult(root=types.EmptyResult())
)