Fix streamable http sampling (#693)

This commit is contained in:
ihrpr
2025-05-12 18:31:35 +01:00
committed by GitHub
parent ed25167fa5
commit c6fb822c86
7 changed files with 152 additions and 23 deletions

View File

@@ -54,7 +54,7 @@ def test_absolute_uv_path(mock_config_path: Path):
"""Test that the absolute path to uv is used when available."""
# Mock the shutil.which function to return a fake path
mock_uv_path = "/usr/local/bin/uv"
with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path):
# Setup
server_name = "test_server"
@@ -71,5 +71,5 @@ def test_absolute_uv_path(mock_config_path: Path):
# Verify the command is the absolute path
server_config = config["mcpServers"][server_name]
command = server_config["command"]
assert command == mock_uv_path
assert command == mock_uv_path

View File

@@ -8,6 +8,7 @@ import multiprocessing
import socket
import time
from collections.abc import Generator
from typing import Any
import anyio
import httpx
@@ -33,6 +34,7 @@ from mcp.server.streamable_http import (
StreamId,
)
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import (
ClientMessageMetadata,
@@ -139,6 +141,11 @@ class ServerTest(Server):
description="A long-running tool that sends periodic notifications",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="test_sampling_tool",
description="A tool that triggers server-side sampling",
inputSchema={"type": "object", "properties": {}},
),
]
@self.call_tool()
@@ -174,6 +181,34 @@ class ServerTest(Server):
return [TextContent(type="text", text="Completed!")]
elif name == "test_sampling_tool":
# Test sampling by requesting the client to sample a message
sampling_result = await ctx.session.create_message(
messages=[
types.SamplingMessage(
role="user",
content=types.TextContent(
type="text", text="Server needs client sampling"
),
)
],
max_tokens=100,
related_request_id=ctx.request_id,
)
# Return the sampling result in the tool response
response = (
sampling_result.content.text
if sampling_result.content.type == "text"
else None
)
return [
TextContent(
type="text",
text=f"Response from sampling: {response}",
)
]
return [TextContent(type="text", text=f"Called {name}")]
@@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
"""Test client tool invocation."""
# First list tools
tools = await initialized_client_session.list_tools()
assert len(tools.tools) == 3
assert len(tools.tools) == 4
assert tools.tools[0].name == "test_tool"
# Call the tool
@@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence(
# Make multiple requests to verify session persistence
tools = await session.list_tools()
assert len(tools.tools) == 3
assert len(tools.tools) == 4
# Read a resource
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response(
# Check tool listing
tools = await session.list_tools()
assert len(tools.tools) == 3
assert len(tools.tools) == 4
# Call a tool and verify JSON response handling
result = await session.call_tool("test_tool", {})
@@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination(
# Make a request to confirm session is working
tools = await session.list_tools()
assert len(tools.tools) == 3
assert len(tools.tools) == 4
headers = {}
if captured_session_id:
@@ -1054,3 +1089,71 @@ async def test_streamablehttp_client_resumption(event_server):
assert not any(
n in captured_notifications_pre for n in captured_notifications
)
@pytest.mark.anyio
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
"""Test server-initiated sampling request through streamable HTTP transport."""
print("Testing server sampling...")
# Variable to track if sampling callback was invoked
sampling_callback_invoked = False
captured_message_params = None
# Define sampling callback that returns a mock response
async def sampling_callback(
context: RequestContext[ClientSession, Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult:
nonlocal sampling_callback_invoked, captured_message_params
sampling_callback_invoked = True
captured_message_params = params
message_received = (
params.messages[0].content.text
if params.messages[0].content.type == "text"
else None
)
return types.CreateMessageResult(
role="assistant",
content=types.TextContent(
type="text",
text=f"Received message from server: {message_received}",
),
model="test-model",
stopReason="endTurn",
)
# Create client with sampling callback
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(
read_stream,
write_stream,
sampling_callback=sampling_callback,
) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
# Call the tool that triggers server-side sampling
tool_result = await session.call_tool("test_sampling_tool", {})
# Verify the tool result contains the expected content
assert len(tool_result.content) == 1
assert tool_result.content[0].type == "text"
assert (
"Response from sampling: Received message from server"
in tool_result.content[0].text
)
# Verify sampling callback was invoked
assert sampling_callback_invoked
assert captured_message_params is not None
assert len(captured_message_params.messages) == 1
assert (
captured_message_params.messages[0].content.text
== "Server needs client sampling"
)