mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Client sampling and roots capabilities set to None if not implemented (#802)
Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
@@ -380,3 +383,167 @@ async def test_client_session_version_negotiation_failure():
|
||||
# Should raise RuntimeError for unsupported version
|
||||
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
|
||||
await session.initialize()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_capabilities_default():
|
||||
"""Test that client capabilities are properly set with default callbacks"""
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
|
||||
received_capabilities = None
|
||||
|
||||
async def mock_server():
|
||||
nonlocal received_capabilities
|
||||
|
||||
session_message = await client_to_server_receive.receive()
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
received_capabilities = request.root.params.capabilities
|
||||
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
async with server_to_client_send:
|
||||
await server_to_client_send.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
# Receive initialized notification
|
||||
await client_to_server_receive.receive()
|
||||
|
||||
async with (
|
||||
ClientSession(
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
) as session,
|
||||
anyio.create_task_group() as tg,
|
||||
client_to_server_send,
|
||||
client_to_server_receive,
|
||||
server_to_client_send,
|
||||
server_to_client_receive,
|
||||
):
|
||||
tg.start_soon(mock_server)
|
||||
await session.initialize()
|
||||
|
||||
# Assert that capabilities are properly set with defaults
|
||||
assert received_capabilities is not None
|
||||
assert received_capabilities.sampling is None # No custom sampling callback
|
||||
assert received_capabilities.roots is None # No custom list_roots callback
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_capabilities_with_custom_callbacks():
|
||||
"""Test that client capabilities are properly set with custom callbacks"""
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
|
||||
received_capabilities = None
|
||||
|
||||
async def custom_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.CreateMessageResult(
|
||||
role="assistant",
|
||||
content=types.TextContent(type="text", text="test"),
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def custom_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ListRootsResult(roots=[])
|
||||
|
||||
async def mock_server():
|
||||
nonlocal received_capabilities
|
||||
|
||||
session_message = await client_to_server_receive.receive()
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
received_capabilities = request.root.params.capabilities
|
||||
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
async with server_to_client_send:
|
||||
await server_to_client_send.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
# Receive initialized notification
|
||||
await client_to_server_receive.receive()
|
||||
|
||||
async with (
|
||||
ClientSession(
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
sampling_callback=custom_sampling_callback,
|
||||
list_roots_callback=custom_list_roots_callback,
|
||||
) as session,
|
||||
anyio.create_task_group() as tg,
|
||||
client_to_server_send,
|
||||
client_to_server_receive,
|
||||
server_to_client_send,
|
||||
server_to_client_receive,
|
||||
):
|
||||
tg.start_soon(mock_server)
|
||||
await session.initialize()
|
||||
|
||||
# Assert that capabilities are properly set with custom callbacks
|
||||
assert received_capabilities is not None
|
||||
assert (
|
||||
received_capabilities.sampling is not None
|
||||
) # Custom sampling callback provided
|
||||
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
|
||||
assert (
|
||||
received_capabilities.roots is not None
|
||||
) # Custom list_roots callback provided
|
||||
assert isinstance(received_capabilities.roots, types.RootsCapability)
|
||||
assert (
|
||||
received_capabilities.roots.listChanged is True
|
||||
) # Should be True for custom callback
|
||||
|
||||
Reference in New Issue
Block a user