Add client handling for sampling, list roots, ping (#218)

Adds sampling and list roots callbacks to the ClientSession, allowing the client to handle requests from the server.

Co-authored-by: TerminalMan <84923604+SecretiveShell@users.noreply.github.com>
Co-authored-by: David Soria Parra <davidsp@anthropic.com>
This commit is contained in:
Jerome
2025-02-20 10:49:43 +00:00
committed by GitHub
parent 106619967b
commit ff22f48365
6 changed files with 256 additions and 12 deletions

View File

@@ -1,13 +1,51 @@
from datetime import timedelta
from typing import Any, Protocol
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from pydantic import AnyUrl, TypeAdapter
import mcp.types as types
from mcp.shared.session import BaseSession
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...
class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
) -> types.ListRootsResult | types.ErrorData: ...
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
@@ -22,6 +60,8 @@ class ClientSession(
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
) -> None:
super().__init__(
read_stream,
@@ -30,8 +70,24 @@ class ClientSession(
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
async def initialize(self) -> types.InitializeResult:
sampling = (
types.SamplingCapability() if self._sampling_callback is not None else None
)
roots = (
types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)
if self._list_roots_callback is not None
else None
)
result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
@@ -39,14 +95,9 @@ class ClientSession(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=None,
sampling=sampling,
experimental=None,
roots=types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True
),
roots=roots,
),
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
),
@@ -243,3 +294,32 @@ class ClientSession(
)
)
)
async def _received_request(
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
response = await self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.PingRequest():
with responder:
return await responder.respond(
types.ClientResult(root=types.EmptyResult())
)

View File

@@ -9,7 +9,7 @@ from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.session import ClientSession
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
from mcp.server import Server
from mcp.types import JSONRPCMessage
@@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> (
async def create_connected_server_and_client_session(
server: Server,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -80,6 +82,8 @@ async def create_connected_server_and_client_session(
read_stream=client_read,
write_stream=client_write,
read_timeout_seconds=read_timeout_seconds,
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
) as client_session:
await client_session.initialize()
yield client_session