mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Fix streamable http sampling (#693)
This commit is contained in:
@@ -8,6 +8,7 @@ import multiprocessing
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
@@ -33,6 +34,7 @@ from mcp.server.streamable_http import (
|
||||
StreamId,
|
||||
)
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.message import (
|
||||
ClientMessageMetadata,
|
||||
@@ -139,6 +141,11 @@ class ServerTest(Server):
|
||||
description="A long-running tool that sends periodic notifications",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
Tool(
|
||||
name="test_sampling_tool",
|
||||
description="A tool that triggers server-side sampling",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
]
|
||||
|
||||
@self.call_tool()
|
||||
@@ -174,6 +181,34 @@ class ServerTest(Server):
|
||||
|
||||
return [TextContent(type="text", text="Completed!")]
|
||||
|
||||
elif name == "test_sampling_tool":
|
||||
# Test sampling by requesting the client to sample a message
|
||||
sampling_result = await ctx.session.create_message(
|
||||
messages=[
|
||||
types.SamplingMessage(
|
||||
role="user",
|
||||
content=types.TextContent(
|
||||
type="text", text="Server needs client sampling"
|
||||
),
|
||||
)
|
||||
],
|
||||
max_tokens=100,
|
||||
related_request_id=ctx.request_id,
|
||||
)
|
||||
|
||||
# Return the sampling result in the tool response
|
||||
response = (
|
||||
sampling_result.content.text
|
||||
if sampling_result.content.type == "text"
|
||||
else None
|
||||
)
|
||||
return [
|
||||
TextContent(
|
||||
type="text",
|
||||
text=f"Response from sampling: {response}",
|
||||
)
|
||||
]
|
||||
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
|
||||
@@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
|
||||
"""Test client tool invocation."""
|
||||
# First list tools
|
||||
tools = await initialized_client_session.list_tools()
|
||||
assert len(tools.tools) == 3
|
||||
assert len(tools.tools) == 4
|
||||
assert tools.tools[0].name == "test_tool"
|
||||
|
||||
# Call the tool
|
||||
@@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence(
|
||||
|
||||
# Make multiple requests to verify session persistence
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 3
|
||||
assert len(tools.tools) == 4
|
||||
|
||||
# Read a resource
|
||||
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
|
||||
@@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response(
|
||||
|
||||
# Check tool listing
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 3
|
||||
assert len(tools.tools) == 4
|
||||
|
||||
# Call a tool and verify JSON response handling
|
||||
result = await session.call_tool("test_tool", {})
|
||||
@@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination(
|
||||
|
||||
# Make a request to confirm session is working
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 3
|
||||
assert len(tools.tools) == 4
|
||||
|
||||
headers = {}
|
||||
if captured_session_id:
|
||||
@@ -1054,3 +1089,71 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
assert not any(
|
||||
n in captured_notifications_pre for n in captured_notifications
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
|
||||
"""Test server-initiated sampling request through streamable HTTP transport."""
|
||||
print("Testing server sampling...")
|
||||
# Variable to track if sampling callback was invoked
|
||||
sampling_callback_invoked = False
|
||||
captured_message_params = None
|
||||
|
||||
# Define sampling callback that returns a mock response
|
||||
async def sampling_callback(
|
||||
context: RequestContext[ClientSession, Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult:
|
||||
nonlocal sampling_callback_invoked, captured_message_params
|
||||
sampling_callback_invoked = True
|
||||
captured_message_params = params
|
||||
message_received = (
|
||||
params.messages[0].content.text
|
||||
if params.messages[0].content.type == "text"
|
||||
else None
|
||||
)
|
||||
|
||||
return types.CreateMessageResult(
|
||||
role="assistant",
|
||||
content=types.TextContent(
|
||||
type="text",
|
||||
text=f"Received message from server: {message_received}",
|
||||
),
|
||||
model="test-model",
|
||||
stopReason="endTurn",
|
||||
)
|
||||
|
||||
# Create client with sampling callback
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
sampling_callback=sampling_callback,
|
||||
) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
|
||||
# Call the tool that triggers server-side sampling
|
||||
tool_result = await session.call_tool("test_sampling_tool", {})
|
||||
|
||||
# Verify the tool result contains the expected content
|
||||
assert len(tool_result.content) == 1
|
||||
assert tool_result.content[0].type == "text"
|
||||
assert (
|
||||
"Response from sampling: Received message from server"
|
||||
in tool_result.content[0].text
|
||||
)
|
||||
|
||||
# Verify sampling callback was invoked
|
||||
assert sampling_callback_invoked
|
||||
assert captured_message_params is not None
|
||||
assert len(captured_message_params.messages) == 1
|
||||
assert (
|
||||
captured_message_params.messages[0].content.text
|
||||
== "Server needs client sampling"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user