Add client handling for sampling, list roots, ping (#218)

Adds sampling and list roots callbacks to the ClientSession, allowing the client to handle requests from the server.

Co-authored-by: TerminalMan <84923604+SecretiveShell@users.noreply.github.com>
Co-authored-by: David Soria Parra <davidsp@anthropic.com>
This commit is contained in:
Jerome
2025-02-20 10:49:43 +00:00
committed by GitHub
parent 106619967b
commit ff22f48365
6 changed files with 256 additions and 12 deletions

View File

@@ -0,0 +1,70 @@
import pytest
from pydantic import FileUrl
from mcp.client.session import ClientSession
from mcp.server.fastmcp.server import Context
from mcp.shared.context import RequestContext
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
from mcp.types import (
ListRootsResult,
Root,
TextContent,
)
@pytest.mark.anyio
async def test_list_roots_callback():
from mcp.server.fastmcp import FastMCP
server = FastMCP("test")
callback_return = ListRootsResult(
roots=[
Root(
uri=FileUrl("file://users/fake/test"),
name="Test Root 1",
),
Root(
uri=FileUrl("file://users/fake/test/2"),
name="Test Root 2",
),
]
)
async def list_roots_callback(
context: RequestContext[ClientSession, None],
) -> ListRootsResult:
return callback_return
@server.tool("test_list_roots")
async def test_list_roots(context: Context, message: str):
roots = await context.session.list_roots()
assert roots == callback_return
return True
# Test with list_roots callback
async with create_session(
server._mcp_server, list_roots_callback=list_roots_callback
) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_list_roots", {"message": "test message"}
)
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"
# Test without list_roots callback
async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_list_roots", {"message": "test message"}
)
assert result.isError is True
assert isinstance(result.content[0], TextContent)
assert (
result.content[0].text
== "Error executing tool test_list_roots: List roots not supported"
)

View File

@@ -0,0 +1,73 @@
import pytest
from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
from mcp.types import (
CreateMessageRequestParams,
CreateMessageResult,
SamplingMessage,
TextContent,
)
@pytest.mark.anyio
async def test_sampling_callback():
from mcp.server.fastmcp import FastMCP
server = FastMCP("test")
callback_return = CreateMessageResult(
role="assistant",
content=TextContent(
type="text", text="This is a response from the sampling callback"
),
model="test-model",
stopReason="endTurn",
)
async def sampling_callback(
context: RequestContext[ClientSession, None],
params: CreateMessageRequestParams,
) -> CreateMessageResult:
return callback_return
@server.tool("test_sampling")
async def test_sampling_tool(message: str):
value = await server.get_context().session.create_message(
messages=[
SamplingMessage(
role="user", content=TextContent(type="text", text=message)
)
],
max_tokens=100,
)
assert value == callback_return
return True
# Test with sampling callback
async with create_session(
server._mcp_server, sampling_callback=sampling_callback
) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_sampling", {"message": "Test message for sampling"}
)
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"
# Test without sampling callback
async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_sampling", {"message": "Test message for sampling"}
)
assert result.isError is True
assert isinstance(result.content[0], TextContent)
assert (
result.content[0].text
== "Error executing tool test_sampling: Sampling not supported"
)

View File

@@ -1,12 +1,17 @@
import shutil
import pytest
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
tee: str = shutil.which("tee") # type: ignore
@pytest.mark.anyio
@pytest.mark.skipif(tee is None, reason="could not find tee command")
async def test_stdio_client():
server_parameters = StdioServerParameters(command="/usr/bin/tee")
server_parameters = StdioServerParameters(command=tee)
async with stdio_client(server_parameters) as (read_stream, write_stream):
# Test sending and receiving messages