mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
Add client handling for sampling, list roots, ping (#218)
Adds sampling and list roots callbacks to the ClientSession, allowing the client to handle requests from the server. Co-authored-by: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Co-authored-by: David Soria Parra <davidsp@anthropic.com>
This commit is contained in:
14
README.md
14
README.md
@@ -476,9 +476,21 @@ server_params = StdioServerParameters(
|
||||
env=None # Optional environment variables
|
||||
)
|
||||
|
||||
# Optional: create a sampling callback
|
||||
async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult:
|
||||
return types.CreateMessageResult(
|
||||
role="assistant",
|
||||
content=types.TextContent(
|
||||
type="text",
|
||||
text="Hello, world! from model",
|
||||
),
|
||||
model="gpt-3.5-turbo",
|
||||
stopReason="endTurn",
|
||||
)
|
||||
|
||||
async def run():
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
|
||||
|
||||
@@ -1,13 +1,51 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import BaseSession, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
async def __call__(
|
||||
self, context: RequestContext["ClientSession", Any]
|
||||
) -> types.ListRootsResult | types.ErrorData: ...
|
||||
|
||||
|
||||
async def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
@@ -22,6 +60,8 @@ class ClientSession(
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
@@ -30,8 +70,24 @@ class ClientSession(
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
|
||||
async def initialize(self) -> types.InitializeResult:
|
||||
sampling = (
|
||||
types.SamplingCapability() if self._sampling_callback is not None else None
|
||||
)
|
||||
roots = (
|
||||
types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True,
|
||||
)
|
||||
if self._list_roots_callback is not None
|
||||
else None
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
@@ -39,14 +95,9 @@ class ClientSession(
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=None,
|
||||
sampling=sampling,
|
||||
experimental=None,
|
||||
roots=types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True
|
||||
),
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
|
||||
),
|
||||
@@ -243,3 +294,32 @@ class ClientSession(
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
) -> None:
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
response = await self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
response = await self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.PingRequest():
|
||||
with responder:
|
||||
return await responder.respond(
|
||||
types.ClientResult(root=types.EmptyResult())
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import AsyncGenerator
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
|
||||
from mcp.server import Server
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
@@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> (
|
||||
async def create_connected_server_and_client_session(
|
||||
server: Server,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
raise_exceptions: bool = False,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||
@@ -80,6 +82,8 @@ async def create_connected_server_and_client_session(
|
||||
read_stream=client_read,
|
||||
write_stream=client_write,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
sampling_callback=sampling_callback,
|
||||
list_roots_callback=list_roots_callback,
|
||||
) as client_session:
|
||||
await client_session.initialize()
|
||||
yield client_session
|
||||
|
||||
70
tests/client/test_list_roots_callback.py
Normal file
70
tests/client/test_list_roots_callback.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
from pydantic import FileUrl
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.memory import (
|
||||
create_connected_server_and_client_session as create_session,
|
||||
)
|
||||
from mcp.types import (
|
||||
ListRootsResult,
|
||||
Root,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_roots_callback():
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
server = FastMCP("test")
|
||||
|
||||
callback_return = ListRootsResult(
|
||||
roots=[
|
||||
Root(
|
||||
uri=FileUrl("file://users/fake/test"),
|
||||
name="Test Root 1",
|
||||
),
|
||||
Root(
|
||||
uri=FileUrl("file://users/fake/test/2"),
|
||||
name="Test Root 2",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def list_roots_callback(
|
||||
context: RequestContext[ClientSession, None],
|
||||
) -> ListRootsResult:
|
||||
return callback_return
|
||||
|
||||
@server.tool("test_list_roots")
|
||||
async def test_list_roots(context: Context, message: str):
|
||||
roots = await context.session.list_roots()
|
||||
assert roots == callback_return
|
||||
return True
|
||||
|
||||
# Test with list_roots callback
|
||||
async with create_session(
|
||||
server._mcp_server, list_roots_callback=list_roots_callback
|
||||
) as client_session:
|
||||
# Make a request to trigger sampling callback
|
||||
result = await client_session.call_tool(
|
||||
"test_list_roots", {"message": "test message"}
|
||||
)
|
||||
assert result.isError is False
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == "true"
|
||||
|
||||
# Test without list_roots callback
|
||||
async with create_session(server._mcp_server) as client_session:
|
||||
# Make a request to trigger sampling callback
|
||||
result = await client_session.call_tool(
|
||||
"test_list_roots", {"message": "test message"}
|
||||
)
|
||||
assert result.isError is True
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert (
|
||||
result.content[0].text
|
||||
== "Error executing tool test_list_roots: List roots not supported"
|
||||
)
|
||||
73
tests/client/test_sampling_callback.py
Normal file
73
tests/client/test_sampling_callback.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import pytest
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.memory import (
|
||||
create_connected_server_and_client_session as create_session,
|
||||
)
|
||||
from mcp.types import (
|
||||
CreateMessageRequestParams,
|
||||
CreateMessageResult,
|
||||
SamplingMessage,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sampling_callback():
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
server = FastMCP("test")
|
||||
|
||||
callback_return = CreateMessageResult(
|
||||
role="assistant",
|
||||
content=TextContent(
|
||||
type="text", text="This is a response from the sampling callback"
|
||||
),
|
||||
model="test-model",
|
||||
stopReason="endTurn",
|
||||
)
|
||||
|
||||
async def sampling_callback(
|
||||
context: RequestContext[ClientSession, None],
|
||||
params: CreateMessageRequestParams,
|
||||
) -> CreateMessageResult:
|
||||
return callback_return
|
||||
|
||||
@server.tool("test_sampling")
|
||||
async def test_sampling_tool(message: str):
|
||||
value = await server.get_context().session.create_message(
|
||||
messages=[
|
||||
SamplingMessage(
|
||||
role="user", content=TextContent(type="text", text=message)
|
||||
)
|
||||
],
|
||||
max_tokens=100,
|
||||
)
|
||||
assert value == callback_return
|
||||
return True
|
||||
|
||||
# Test with sampling callback
|
||||
async with create_session(
|
||||
server._mcp_server, sampling_callback=sampling_callback
|
||||
) as client_session:
|
||||
# Make a request to trigger sampling callback
|
||||
result = await client_session.call_tool(
|
||||
"test_sampling", {"message": "Test message for sampling"}
|
||||
)
|
||||
assert result.isError is False
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == "true"
|
||||
|
||||
# Test without sampling callback
|
||||
async with create_session(server._mcp_server) as client_session:
|
||||
# Make a request to trigger sampling callback
|
||||
result = await client_session.call_tool(
|
||||
"test_sampling", {"message": "Test message for sampling"}
|
||||
)
|
||||
assert result.isError is True
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert (
|
||||
result.content[0].text
|
||||
== "Error executing tool test_sampling: Sampling not supported"
|
||||
)
|
||||
@@ -1,12 +1,17 @@
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||
|
||||
tee: str = shutil.which("tee") # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.skipif(tee is None, reason="could not find tee command")
|
||||
async def test_stdio_client():
|
||||
server_parameters = StdioServerParameters(command="/usr/bin/tee")
|
||||
server_parameters = StdioServerParameters(command=tee)
|
||||
|
||||
async with stdio_client(server_parameters) as (read_stream, write_stream):
|
||||
# Test sending and receiving messages
|
||||
|
||||
Reference in New Issue
Block a user