mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Client sampling and roots capabilities set to None if not implemented (#802)
Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
@@ -116,12 +116,18 @@ class ClientSession(
|
|||||||
self._message_handler = message_handler or _default_message_handler
|
self._message_handler = message_handler or _default_message_handler
|
||||||
|
|
||||||
async def initialize(self) -> types.InitializeResult:
|
async def initialize(self) -> types.InitializeResult:
|
||||||
sampling = types.SamplingCapability()
|
sampling = (
|
||||||
roots = types.RootsCapability(
|
types.SamplingCapability()
|
||||||
|
if self._sampling_callback is not _default_sampling_callback
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
roots = (
|
||||||
# TODO: Should this be based on whether we
|
# TODO: Should this be based on whether we
|
||||||
# _will_ send notifications, or only whether
|
# _will_ send notifications, or only whether
|
||||||
# they're supported?
|
# they're supported?
|
||||||
listChanged=True,
|
types.RootsCapability(listChanged=True)
|
||||||
|
if self._list_roots_callback is not _default_list_roots_callback
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self.send_request(
|
result = await self.send_request(
|
||||||
|
|||||||
@@ -218,7 +218,7 @@ class RootsCapability(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SamplingCapability(BaseModel):
|
class SamplingCapability(BaseModel):
|
||||||
"""Capability for logging operations."""
|
"""Capability for sampling operations."""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
|
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
|
||||||
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.shared.message import SessionMessage
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||||
@@ -380,3 +383,167 @@ async def test_client_session_version_negotiation_failure():
|
|||||||
# Should raise RuntimeError for unsupported version
|
# Should raise RuntimeError for unsupported version
|
||||||
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
|
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_capabilities_default():
|
||||||
|
"""Test that client capabilities are properly set with default callbacks"""
|
||||||
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
|
SessionMessage
|
||||||
|
](1)
|
||||||
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
|
SessionMessage
|
||||||
|
](1)
|
||||||
|
|
||||||
|
received_capabilities = None
|
||||||
|
|
||||||
|
async def mock_server():
|
||||||
|
nonlocal received_capabilities
|
||||||
|
|
||||||
|
session_message = await client_to_server_receive.receive()
|
||||||
|
jsonrpc_request = session_message.message
|
||||||
|
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||||
|
request = ClientRequest.model_validate(
|
||||||
|
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
|
)
|
||||||
|
assert isinstance(request.root, InitializeRequest)
|
||||||
|
received_capabilities = request.root.params.capabilities
|
||||||
|
|
||||||
|
result = ServerResult(
|
||||||
|
InitializeResult(
|
||||||
|
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||||
|
capabilities=ServerCapabilities(),
|
||||||
|
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with server_to_client_send:
|
||||||
|
await server_to_client_send.send(
|
||||||
|
SessionMessage(
|
||||||
|
JSONRPCMessage(
|
||||||
|
JSONRPCResponse(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
id=jsonrpc_request.root.id,
|
||||||
|
result=result.model_dump(
|
||||||
|
by_alias=True, mode="json", exclude_none=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Receive initialized notification
|
||||||
|
await client_to_server_receive.receive()
|
||||||
|
|
||||||
|
async with (
|
||||||
|
ClientSession(
|
||||||
|
server_to_client_receive,
|
||||||
|
client_to_server_send,
|
||||||
|
) as session,
|
||||||
|
anyio.create_task_group() as tg,
|
||||||
|
client_to_server_send,
|
||||||
|
client_to_server_receive,
|
||||||
|
server_to_client_send,
|
||||||
|
server_to_client_receive,
|
||||||
|
):
|
||||||
|
tg.start_soon(mock_server)
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
# Assert that capabilities are properly set with defaults
|
||||||
|
assert received_capabilities is not None
|
||||||
|
assert received_capabilities.sampling is None # No custom sampling callback
|
||||||
|
assert received_capabilities.roots is None # No custom list_roots callback
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_capabilities_with_custom_callbacks():
|
||||||
|
"""Test that client capabilities are properly set with custom callbacks"""
|
||||||
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
|
SessionMessage
|
||||||
|
](1)
|
||||||
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
|
SessionMessage
|
||||||
|
](1)
|
||||||
|
|
||||||
|
received_capabilities = None
|
||||||
|
|
||||||
|
async def custom_sampling_callback(
|
||||||
|
context: RequestContext["ClientSession", Any],
|
||||||
|
params: types.CreateMessageRequestParams,
|
||||||
|
) -> types.CreateMessageResult | types.ErrorData:
|
||||||
|
return types.CreateMessageResult(
|
||||||
|
role="assistant",
|
||||||
|
content=types.TextContent(type="text", text="test"),
|
||||||
|
model="test-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def custom_list_roots_callback(
|
||||||
|
context: RequestContext["ClientSession", Any],
|
||||||
|
) -> types.ListRootsResult | types.ErrorData:
|
||||||
|
return types.ListRootsResult(roots=[])
|
||||||
|
|
||||||
|
async def mock_server():
|
||||||
|
nonlocal received_capabilities
|
||||||
|
|
||||||
|
session_message = await client_to_server_receive.receive()
|
||||||
|
jsonrpc_request = session_message.message
|
||||||
|
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||||
|
request = ClientRequest.model_validate(
|
||||||
|
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
|
)
|
||||||
|
assert isinstance(request.root, InitializeRequest)
|
||||||
|
received_capabilities = request.root.params.capabilities
|
||||||
|
|
||||||
|
result = ServerResult(
|
||||||
|
InitializeResult(
|
||||||
|
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||||
|
capabilities=ServerCapabilities(),
|
||||||
|
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with server_to_client_send:
|
||||||
|
await server_to_client_send.send(
|
||||||
|
SessionMessage(
|
||||||
|
JSONRPCMessage(
|
||||||
|
JSONRPCResponse(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
id=jsonrpc_request.root.id,
|
||||||
|
result=result.model_dump(
|
||||||
|
by_alias=True, mode="json", exclude_none=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Receive initialized notification
|
||||||
|
await client_to_server_receive.receive()
|
||||||
|
|
||||||
|
async with (
|
||||||
|
ClientSession(
|
||||||
|
server_to_client_receive,
|
||||||
|
client_to_server_send,
|
||||||
|
sampling_callback=custom_sampling_callback,
|
||||||
|
list_roots_callback=custom_list_roots_callback,
|
||||||
|
) as session,
|
||||||
|
anyio.create_task_group() as tg,
|
||||||
|
client_to_server_send,
|
||||||
|
client_to_server_receive,
|
||||||
|
server_to_client_send,
|
||||||
|
server_to_client_receive,
|
||||||
|
):
|
||||||
|
tg.start_soon(mock_server)
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
# Assert that capabilities are properly set with custom callbacks
|
||||||
|
assert received_capabilities is not None
|
||||||
|
assert (
|
||||||
|
received_capabilities.sampling is not None
|
||||||
|
) # Custom sampling callback provided
|
||||||
|
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
|
||||||
|
assert (
|
||||||
|
received_capabilities.roots is not None
|
||||||
|
) # Custom list_roots callback provided
|
||||||
|
assert isinstance(received_capabilities.roots, types.RootsCapability)
|
||||||
|
assert (
|
||||||
|
received_capabilities.roots.listChanged is True
|
||||||
|
) # Should be True for custom callback
|
||||||
|
|||||||
Reference in New Issue
Block a user