Client sampling and roots capabilities set to None if not implemented (#802)

Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
Lorenzo
2025-05-29 05:56:34 -03:00
committed by GitHub
parent d55cb2bf4e
commit 7f94bef85e
3 changed files with 177 additions and 4 deletions

View File

@@ -1,8 +1,11 @@
from typing import Any
import anyio
import pytest
import mcp.types as types
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -380,3 +383,167 @@ async def test_client_session_version_negotiation_failure():
# Should raise RuntimeError for unsupported version
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
await session.initialize()
@pytest.mark.anyio
async def test_client_capabilities_default():
"""Test that client capabilities are properly set with default callbacks"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
received_capabilities = None
async def mock_server():
nonlocal received_capabilities
session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request.root, InitializeRequest)
received_capabilities = request.root.params.capabilities
result = ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
serverInfo=Implementation(name="mock-server", version="0.1.0"),
)
)
async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()
async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()
# Assert that capabilities are properly set with defaults
assert received_capabilities is not None
assert received_capabilities.sampling is None # No custom sampling callback
assert received_capabilities.roots is None # No custom list_roots callback
@pytest.mark.anyio
async def test_client_capabilities_with_custom_callbacks():
"""Test that client capabilities are properly set with custom callbacks"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
received_capabilities = None
async def custom_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.CreateMessageResult(
role="assistant",
content=types.TextContent(type="text", text="test"),
model="test-model",
)
async def custom_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ListRootsResult(roots=[])
async def mock_server():
nonlocal received_capabilities
session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request.root, InitializeRequest)
received_capabilities = request.root.params.capabilities
result = ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
serverInfo=Implementation(name="mock-server", version="0.1.0"),
)
)
async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()
async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
sampling_callback=custom_sampling_callback,
list_roots_callback=custom_list_roots_callback,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()
# Assert that capabilities are properly set with custom callbacks
assert received_capabilities is not None
assert (
received_capabilities.sampling is not None
) # Custom sampling callback provided
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
assert (
received_capabilities.roots is not None
) # Custom list_roots callback provided
assert isinstance(received_capabilities.roots, types.RootsCapability)
assert (
received_capabilities.roots.listChanged is True
) # Should be True for custom callback