mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
StreamableHttp - client refactoring and resumability support (#595)
This commit is contained in:
@@ -23,16 +23,31 @@ from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.server import Server
|
||||
from mcp.server.streamable_http import (
|
||||
MCP_SESSION_ID_HEADER,
|
||||
SESSION_ID_PATTERN,
|
||||
EventCallback,
|
||||
EventId,
|
||||
EventMessage,
|
||||
EventStore,
|
||||
StreamableHTTPServerTransport,
|
||||
StreamId,
|
||||
)
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool
|
||||
from mcp.shared.message import (
|
||||
ClientMessageMetadata,
|
||||
)
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import (
|
||||
InitializeResult,
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
Tool,
|
||||
)
|
||||
|
||||
# Test constants
|
||||
SERVER_NAME = "test_streamable_http_server"
|
||||
@@ -49,6 +64,51 @@ INIT_REQUEST = {
|
||||
}
|
||||
|
||||
|
||||
# Simple in-memory event store for testing
|
||||
class SimpleEventStore(EventStore):
|
||||
"""Simple in-memory event store for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
|
||||
self._event_id_counter = 0
|
||||
|
||||
async def store_event(
|
||||
self, stream_id: StreamId, message: types.JSONRPCMessage
|
||||
) -> EventId:
|
||||
"""Store an event and return its ID."""
|
||||
self._event_id_counter += 1
|
||||
event_id = str(self._event_id_counter)
|
||||
self._events.append((stream_id, event_id, message))
|
||||
return event_id
|
||||
|
||||
async def replay_events_after(
|
||||
self,
|
||||
last_event_id: EventId,
|
||||
send_callback: EventCallback,
|
||||
) -> StreamId | None:
|
||||
"""Replay events after the specified ID."""
|
||||
# Find the index of the last event ID
|
||||
start_index = None
|
||||
for i, (_, event_id, _) in enumerate(self._events):
|
||||
if event_id == last_event_id:
|
||||
start_index = i + 1
|
||||
break
|
||||
|
||||
if start_index is None:
|
||||
# If event ID not found, start from beginning
|
||||
start_index = 0
|
||||
|
||||
stream_id = None
|
||||
# Replay events
|
||||
for _, event_id, message in self._events[start_index:]:
|
||||
await send_callback(EventMessage(message, event_id))
|
||||
# Capture the stream ID from the first replayed event
|
||||
if stream_id is None and len(self._events) > start_index:
|
||||
stream_id = self._events[start_index][0]
|
||||
|
||||
return stream_id
|
||||
|
||||
|
||||
# Test server implementation that follows MCP protocol
|
||||
class ServerTest(Server):
|
||||
def __init__(self):
|
||||
@@ -78,25 +138,57 @@ class ServerTest(Server):
|
||||
description="A test tool that sends a notification",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
Tool(
|
||||
name="long_running_with_checkpoints",
|
||||
description="A long-running tool that sends periodic notifications",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
]
|
||||
|
||||
@self.call_tool()
|
||||
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
||||
ctx = self.request_context
|
||||
|
||||
# When the tool is called, send a notification to test GET stream
|
||||
if name == "test_tool_with_standalone_notification":
|
||||
ctx = self.request_context
|
||||
await ctx.session.send_resource_updated(
|
||||
uri=AnyUrl("http://test_resource")
|
||||
)
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
elif name == "long_running_with_checkpoints":
|
||||
# Send notifications that are part of the response stream
|
||||
# This simulates a long-running tool that sends logs
|
||||
|
||||
await ctx.session.send_log_message(
|
||||
level="info",
|
||||
data="Tool started",
|
||||
logger="tool",
|
||||
related_request_id=ctx.request_id, # need for stream association
|
||||
)
|
||||
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
await ctx.session.send_log_message(
|
||||
level="info",
|
||||
data="Tool is almost done",
|
||||
logger="tool",
|
||||
related_request_id=ctx.request_id,
|
||||
)
|
||||
|
||||
return [TextContent(type="text", text="Completed!")]
|
||||
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
|
||||
def create_app(is_json_response_enabled=False) -> Starlette:
|
||||
def create_app(
|
||||
is_json_response_enabled=False, event_store: EventStore | None = None
|
||||
) -> Starlette:
|
||||
"""Create a Starlette application for testing that matches the example server.
|
||||
|
||||
Args:
|
||||
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
|
||||
event_store: Optional event store for testing resumability.
|
||||
"""
|
||||
# Create server instance
|
||||
server = ServerTest()
|
||||
@@ -139,6 +231,7 @@ def create_app(is_json_response_enabled=False) -> Starlette:
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=new_session_id,
|
||||
is_json_response_enabled=is_json_response_enabled,
|
||||
event_store=event_store,
|
||||
)
|
||||
|
||||
async with http_transport.connect() as streams:
|
||||
@@ -183,15 +276,18 @@ def create_app(is_json_response_enabled=False) -> Starlette:
|
||||
return app
|
||||
|
||||
|
||||
def run_server(port: int, is_json_response_enabled=False) -> None:
|
||||
def run_server(
|
||||
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
|
||||
) -> None:
|
||||
"""Run the test server.
|
||||
|
||||
Args:
|
||||
port: Port to listen on.
|
||||
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
|
||||
event_store: Optional event store for testing resumability.
|
||||
"""
|
||||
|
||||
app = create_app(is_json_response_enabled)
|
||||
app = create_app(is_json_response_enabled, event_store)
|
||||
# Configure server
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
@@ -261,6 +357,53 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
proc.join(timeout=2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_store() -> SimpleEventStore:
|
||||
"""Create a test event store."""
|
||||
return SimpleEventStore()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_server_port() -> int:
|
||||
"""Find an available port for the event store server."""
|
||||
with socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_server(
|
||||
event_server_port: int, event_store: SimpleEventStore
|
||||
) -> Generator[tuple[SimpleEventStore, str], None, None]:
|
||||
"""Start a server with event store enabled."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_server,
|
||||
kwargs={"port": event_server_port, "event_store": event_store},
|
||||
daemon=True,
|
||||
)
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", event_server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield event_store, f"http://127.0.0.1:{event_server_port}"
|
||||
|
||||
# Clean up
|
||||
proc.kill()
|
||||
proc.join(timeout=2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_response_server(json_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start a server with JSON response enabled."""
|
||||
@@ -679,7 +822,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
|
||||
"""Test client tool invocation."""
|
||||
# First list tools
|
||||
tools = await initialized_client_session.list_tools()
|
||||
assert len(tools.tools) == 2
|
||||
assert len(tools.tools) == 3
|
||||
assert tools.tools[0].name == "test_tool"
|
||||
|
||||
# Call the tool
|
||||
@@ -720,7 +863,7 @@ async def test_streamablehttp_client_session_persistence(
|
||||
|
||||
# Make multiple requests to verify session persistence
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 2
|
||||
assert len(tools.tools) == 3
|
||||
|
||||
# Read a resource
|
||||
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
|
||||
@@ -751,7 +894,7 @@ async def test_streamablehttp_client_json_response(
|
||||
|
||||
# Check tool listing
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 2
|
||||
assert len(tools.tools) == 3
|
||||
|
||||
# Call a tool and verify JSON response handling
|
||||
result = await session.call_tool("test_tool", {})
|
||||
@@ -813,25 +956,169 @@ async def test_streamablehttp_client_session_termination(
|
||||
):
|
||||
"""Test client session termination functionality."""
|
||||
|
||||
captured_session_id = None
|
||||
|
||||
# Create the streamablehttp_client with a custom httpx client to capture headers
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
terminate_session,
|
||||
get_session_id,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
captured_session_id = get_session_id()
|
||||
assert captured_session_id is not None
|
||||
|
||||
# Make a request to confirm session is working
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 2
|
||||
assert len(tools.tools) == 3
|
||||
|
||||
# After exiting ClientSession context, explicitly terminate the session
|
||||
await terminate_session()
|
||||
headers = {}
|
||||
if captured_session_id:
|
||||
headers[MCP_SESSION_ID_HEADER] = captured_session_id
|
||||
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# Attempt to make a request after termination
|
||||
with pytest.raises(
|
||||
McpError,
|
||||
match="Session terminated",
|
||||
):
|
||||
await session.list_tools()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_resumption(event_server):
|
||||
"""Test client session to resume a long running tool."""
|
||||
_, server_url = event_server
|
||||
|
||||
# Variables to track the state
|
||||
captured_session_id = None
|
||||
captured_resumption_token = None
|
||||
captured_notifications = []
|
||||
tool_started = False
|
||||
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, types.ServerNotification):
|
||||
captured_notifications.append(message)
|
||||
# Look for our special notification that indicates the tool is running
|
||||
if isinstance(message.root, types.LoggingMessageNotification):
|
||||
if message.root.params.data == "Tool started":
|
||||
nonlocal tool_started
|
||||
tool_started = True
|
||||
|
||||
async def on_resumption_token_update(token: str) -> None:
|
||||
nonlocal captured_resumption_token
|
||||
captured_resumption_token = token
|
||||
|
||||
# First, start the client session and begin the long-running tool
|
||||
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
get_session_id,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
captured_session_id = get_session_id()
|
||||
assert captured_session_id is not None
|
||||
|
||||
# Start a long-running tool in a task
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
async def run_tool():
|
||||
metadata = ClientMessageMetadata(
|
||||
on_resumption_token_update=on_resumption_token_update,
|
||||
)
|
||||
await session.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(
|
||||
name="long_running_with_checkpoints", arguments={}
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
tg.start_soon(run_tool)
|
||||
|
||||
# Wait for the tool to start and at least one notification
|
||||
# and then kill the task group
|
||||
while not tool_started or not captured_resumption_token:
|
||||
await anyio.sleep(0.1)
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
# Store pre notifications and clear the captured notifications
|
||||
# for the post-resumption check
|
||||
captured_notifications_pre = captured_notifications.copy()
|
||||
captured_notifications = []
|
||||
|
||||
# Now resume the session with the same mcp-session-id
|
||||
headers = {}
|
||||
if captured_session_id:
|
||||
headers[MCP_SESSION_ID_HEADER] = captured_session_id
|
||||
|
||||
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
# Don't initialize - just use the existing session
|
||||
|
||||
# Resume the tool with the resumption token
|
||||
assert captured_resumption_token is not None
|
||||
|
||||
metadata = ClientMessageMetadata(
|
||||
resumption_token=captured_resumption_token,
|
||||
)
|
||||
result = await session.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(
|
||||
name="long_running_with_checkpoints", arguments={}
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# We should get a complete result
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0].type == "text"
|
||||
assert "Completed" in result.content[0].text
|
||||
|
||||
# We should have received the remaining notifications
|
||||
assert len(captured_notifications) > 0
|
||||
|
||||
# Should not have the first notification
|
||||
# Check that "Tool started" notification isn't repeated when resuming
|
||||
assert not any(
|
||||
isinstance(n.root, types.LoggingMessageNotification)
|
||||
and n.root.params.data == "Tool started"
|
||||
for n in captured_notifications
|
||||
)
|
||||
# there is no intersection between pre and post notifications
|
||||
assert not any(
|
||||
n in captured_notifications_pre for n in captured_notifications
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user