StreamableHttp - client refactoring and resumability support (#595)

ihrpr
2025-05-02 14:49:50 +01:00
committed by GitHub
parent cf8b66b82f
commit 74f5fcfa0d
5 changed files with 733 additions and 218 deletions


@@ -23,16 +23,31 @@ from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.server import Server
from mcp.server.streamable_http import (
MCP_SESSION_ID_HEADER,
SESSION_ID_PATTERN,
EventCallback,
EventId,
EventMessage,
EventStore,
StreamableHTTPServerTransport,
StreamId,
)
from mcp.shared.exceptions import McpError
from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool
from mcp.shared.message import (
ClientMessageMetadata,
)
from mcp.shared.session import RequestResponder
from mcp.types import (
InitializeResult,
TextContent,
TextResourceContents,
Tool,
)
# Test constants
SERVER_NAME = "test_streamable_http_server"
@@ -49,6 +64,51 @@ INIT_REQUEST = {
}
# Simple in-memory event store for testing
class SimpleEventStore(EventStore):
"""Simple in-memory event store for testing."""
def __init__(self):
self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
self._event_id_counter = 0
async def store_event(
self, stream_id: StreamId, message: types.JSONRPCMessage
) -> EventId:
"""Store an event and return its ID."""
self._event_id_counter += 1
event_id = str(self._event_id_counter)
self._events.append((stream_id, event_id, message))
return event_id
async def replay_events_after(
self,
last_event_id: EventId,
send_callback: EventCallback,
) -> StreamId | None:
"""Replay events after the specified ID."""
# Find the index of the last event ID
start_index = None
for i, (_, event_id, _) in enumerate(self._events):
if event_id == last_event_id:
start_index = i + 1
break
if start_index is None:
# If event ID not found, start from beginning
start_index = 0
stream_id = None
# Replay events
for _, event_id, message in self._events[start_index:]:
await send_callback(EventMessage(message, event_id))
# Capture the stream ID from the first replayed event
if stream_id is None and len(self._events) > start_index:
stream_id = self._events[start_index][0]
return stream_id
# Test server implementation that follows MCP protocol
class ServerTest(Server):
def __init__(self):
@@ -78,25 +138,57 @@ class ServerTest(Server):
description="A test tool that sends a notification",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="long_running_with_checkpoints",
description="A long-running tool that sends periodic notifications",
inputSchema={"type": "object", "properties": {}},
),
]
@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
ctx = self.request_context
# When the tool is called, send a notification to test GET stream
if name == "test_tool_with_standalone_notification":
-                ctx = self.request_context
await ctx.session.send_resource_updated(
uri=AnyUrl("http://test_resource")
)
return [TextContent(type="text", text=f"Called {name}")]
elif name == "long_running_with_checkpoints":
# Send notifications that are part of the response stream
# This simulates a long-running tool that sends logs
await ctx.session.send_log_message(
level="info",
data="Tool started",
logger="tool",
                related_request_id=ctx.request_id,  # needed to associate the notification with this request's stream
)
await anyio.sleep(0.1)
await ctx.session.send_log_message(
level="info",
data="Tool is almost done",
logger="tool",
related_request_id=ctx.request_id,
)
return [TextContent(type="text", text="Completed!")]
return [TextContent(type="text", text=f"Called {name}")]
-def create_app(is_json_response_enabled=False) -> Starlette:
+def create_app(
+    is_json_response_enabled=False, event_store: EventStore | None = None
+) -> Starlette:
"""Create a Starlette application for testing that matches the example server.
Args:
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
event_store: Optional event store for testing resumability.
"""
# Create server instance
server = ServerTest()
@@ -139,6 +231,7 @@ def create_app(is_json_response_enabled=False) -> Starlette:
http_transport = StreamableHTTPServerTransport(
mcp_session_id=new_session_id,
is_json_response_enabled=is_json_response_enabled,
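            # Providing an event store lets the transport persist SSE events and replay them on reconnect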
event_store=event_store,
)
async with http_transport.connect() as streams:
@@ -183,15 +276,18 @@ def create_app(is_json_response_enabled=False) -> Starlette:
return app
-def run_server(port: int, is_json_response_enabled=False) -> None:
+def run_server(
+    port: int, is_json_response_enabled=False, event_store: EventStore | None = None
+) -> None:
"""Run the test server.
Args:
port: Port to listen on.
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
event_store: Optional event store for testing resumability.
"""
-    app = create_app(is_json_response_enabled)
+    app = create_app(is_json_response_enabled, event_store)
# Configure server
config = uvicorn.Config(
app=app,
@@ -261,6 +357,53 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]:
proc.join(timeout=2)
@pytest.fixture
def event_store() -> SimpleEventStore:
"""Create a test event store."""
return SimpleEventStore()
@pytest.fixture
def event_server_port() -> int:
"""Find an available port for the event store server."""
with socket.socket() as s:
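        # Bind to port 0 so the OS assigns a free ephemeral port; it is released on close for the server process to claim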
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
@pytest.fixture
def event_server(
event_server_port: int, event_store: SimpleEventStore
) -> Generator[tuple[SimpleEventStore, str], None, None]:
"""Start a server with event store enabled."""
proc = multiprocessing.Process(
target=run_server,
kwargs={"port": event_server_port, "event_store": event_store},
daemon=True,
)
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", event_server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
yield event_store, f"http://127.0.0.1:{event_server_port}"
# Clean up
proc.kill()
proc.join(timeout=2)
@pytest.fixture
def json_response_server(json_server_port: int) -> Generator[None, None, None]:
"""Start a server with JSON response enabled."""
@@ -679,7 +822,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
"""Test client tool invocation."""
# First list tools
tools = await initialized_client_session.list_tools()
-    assert len(tools.tools) == 2
+    assert len(tools.tools) == 3
assert tools.tools[0].name == "test_tool"
# Call the tool
@@ -720,7 +863,7 @@ async def test_streamablehttp_client_session_persistence(
# Make multiple requests to verify session persistence
tools = await session.list_tools()
-    assert len(tools.tools) == 2
+    assert len(tools.tools) == 3
# Read a resource
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -751,7 +894,7 @@ async def test_streamablehttp_client_json_response(
# Check tool listing
tools = await session.list_tools()
-    assert len(tools.tools) == 2
+    assert len(tools.tools) == 3
# Call a tool and verify JSON response handling
result = await session.call_tool("test_tool", {})
@@ -813,25 +956,169 @@ async def test_streamablehttp_client_session_termination(
):
"""Test client session termination functionality."""
captured_session_id = None
    # Create the streamablehttp_client and capture the session ID via get_session_id
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
-        terminate_session,
+        get_session_id,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
captured_session_id = get_session_id()
assert captured_session_id is not None
# Make a request to confirm session is working
tools = await session.list_tools()
-            assert len(tools.tools) == 2
+            assert len(tools.tools) == 3
-    # After exiting ClientSession context, explicitly terminate the session
-    await terminate_session()
headers = {}
if captured_session_id:
headers[MCP_SESSION_ID_HEADER] = captured_session_id
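    # Reconnect with the old session ID; the server should reject requests on a terminated session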
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
# Attempt to make a request after termination
with pytest.raises(
McpError,
match="Session terminated",
):
await session.list_tools()
@pytest.mark.anyio
async def test_streamablehttp_client_resumption(event_server):
"""Test client session to resume a long running tool."""
_, server_url = event_server
# Variables to track the state
captured_session_id = None
captured_resumption_token = None
captured_notifications = []
tool_started = False
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, types.ServerNotification):
captured_notifications.append(message)
# Look for our special notification that indicates the tool is running
if isinstance(message.root, types.LoggingMessageNotification):
if message.root.params.data == "Tool started":
nonlocal tool_started
tool_started = True
async def on_resumption_token_update(token: str) -> None:
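        # The token is effectively the last SSE event ID the client saw; store it so the call can be resumed later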
nonlocal captured_resumption_token
captured_resumption_token = token
# First, start the client session and begin the long-running tool
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
read_stream,
write_stream,
get_session_id,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
captured_session_id = get_session_id()
assert captured_session_id is not None
# Start a long-running tool in a task
async with anyio.create_task_group() as tg:
async def run_tool():
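                    # Attach the resumption-token callback to this request via message metadata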
metadata = ClientMessageMetadata(
on_resumption_token_update=on_resumption_token_update,
)
await session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="long_running_with_checkpoints", arguments={}
),
)
),
types.CallToolResult,
metadata=metadata,
)
tg.start_soon(run_tool)
                # Wait until the tool has started and a resumption token has been captured,
                # then cancel the task group to simulate a dropped connection
while not tool_started or not captured_resumption_token:
await anyio.sleep(0.1)
tg.cancel_scope.cancel()
    # Save the pre-cancellation notifications and clear the list
    # for the post-resumption check
captured_notifications_pre = captured_notifications.copy()
captured_notifications = []
# Now resume the session with the same mcp-session-id
headers = {}
if captured_session_id:
headers[MCP_SESSION_ID_HEADER] = captured_session_id
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
# Don't initialize - just use the existing session
# Resume the tool with the resumption token
assert captured_resumption_token is not None
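            # Sending the captured token (in the spirit of SSE Last-Event-ID) asks the server to replay events recorded after it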
metadata = ClientMessageMetadata(
resumption_token=captured_resumption_token,
)
result = await session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="long_running_with_checkpoints", arguments={}
),
)
),
types.CallToolResult,
metadata=metadata,
)
# We should get a complete result
assert len(result.content) == 1
assert result.content[0].type == "text"
assert "Completed" in result.content[0].text
# We should have received the remaining notifications
assert len(captured_notifications) > 0
            # The "Tool started" notification must not be repeated after resumption
assert not any(
isinstance(n.root, types.LoggingMessageNotification)
and n.root.params.data == "Tool started"
for n in captured_notifications
)
            # There should be no overlap between pre- and post-resumption notifications
assert not any(
n in captured_notifications_pre for n in captured_notifications
)