StreamableHttp - client refactoring and resumability support (#595)

This commit is contained in:
ihrpr
2025-05-02 14:49:50 +01:00
committed by GitHub
parent cf8b66b82f
commit 74f5fcfa0d
5 changed files with 733 additions and 218 deletions

View File

@@ -261,6 +261,7 @@ class ClientSession(
read_timeout_seconds: timedelta | None = None,
) -> types.CallToolResult:
"""Send a tools/call request."""
return await self.send_request(
types.ClientRequest(
types.CallToolRequest(

View File

@@ -7,32 +7,403 @@ and session management.
"""
import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any
import anyio
import httpx
from httpx_sse import EventSource, aconnect_sse
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from mcp.shared.message import SessionMessage
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
)
logger = logging.getLogger(__name__)
# Header names
MCP_SESSION_ID_HEADER = "mcp-session-id"
LAST_EVENT_ID_HEADER = "last-event-id"
# Content types
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_SSE = "text/event-stream"
SessionMessageOrError = SessionMessage | Exception
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
StreamReader = MemoryObjectReceiveStream[SessionMessage]
GetSessionIdCallback = Callable[[], str | None]
MCP_SESSION_ID = "mcp-session-id"
LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "Accept"
JSON = "application/json"
SSE = "text/event-stream"
class StreamableHTTPError(Exception):
    """Base exception for StreamableHTTP transport errors."""
class ResumptionError(StreamableHTTPError):
    """Raised when a resumption request is invalid."""
@dataclass
class RequestContext:
    """Context for a single request operation.

    Bundles everything a request handler needs so the POST and resumption
    paths can share one call signature.
    """

    # HTTP client used to issue the request.
    client: httpx.AsyncClient
    # Base headers for the request (session ID is merged in by the handler).
    headers: dict[str, str]
    # Session ID at the time the request was built, if any.
    session_id: str | None
    # The outgoing message plus its transport metadata.
    session_message: SessionMessage
    # Client metadata (resumption token / token-update callback), if provided.
    metadata: ClientMessageMetadata | None
    # Stream on which responses and errors are delivered to the session.
    read_stream_writer: StreamWriter
    # Read timeout applied to SSE connections for this request.
    sse_read_timeout: timedelta
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""
def __init__(
    self,
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: timedelta = timedelta(seconds=30),
    sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
) -> None:
    """Initialize the StreamableHTTP transport.

    Args:
        url: The endpoint URL.
        headers: Optional headers to include in requests.
        timeout: HTTP timeout for regular operations.
        sse_read_timeout: Timeout for SSE read operations.
    """
    self.url = url
    self.headers = headers or {}
    self.timeout = timeout
    self.sse_read_timeout = sse_read_timeout
    # Assigned once the server returns an mcp-session-id header.
    self.session_id: str | None = None
    # Default headers: advertise both JSON and SSE, send JSON bodies.
    # Caller-supplied headers take precedence over the defaults.
    base_headers = {ACCEPT: f"{JSON}, {SSE}", CONTENT_TYPE: JSON}
    base_headers.update(self.headers)
    self.request_headers = base_headers
def _update_headers_with_session(
    self, base_headers: dict[str, str]
) -> dict[str, str]:
    """Return a copy of *base_headers* with the session ID header added when set."""
    merged = dict(base_headers)
    if self.session_id:
        merged[MCP_SESSION_ID] = self.session_id
    return merged
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
    """Return True when *message* wraps an ``initialize`` JSON-RPC request."""
    root = message.root
    return isinstance(root, JSONRPCRequest) and root.method == "initialize"
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
    """Return True when *message* is a ``notifications/initialized`` notification."""
    root = message.root
    if not isinstance(root, JSONRPCNotification):
        return False
    return root.method == "notifications/initialized"
def _maybe_extract_session_id_from_response(
    self,
    response: httpx.Response,
) -> None:
    """Capture the session ID from the response headers, if the server sent one."""
    received = response.headers.get(MCP_SESSION_ID)
    if not received:
        return
    self.session_id = received
    logger.info(f"Received session ID: {self.session_id}")
async def _handle_sse_event(
    self,
    sse: ServerSentEvent,
    read_stream_writer: StreamWriter,
    original_request_id: RequestId | None = None,
    resumption_callback: Callable[[str], Awaitable[None]] | None = None,
) -> bool:
    """Handle an SSE event, returning True if the response is complete.

    Args:
        sse: Server-sent event; only ``message`` events carry JSON-RPC payloads.
        read_stream_writer: Stream on which parsed messages (or parse errors)
            are delivered to the session.
        original_request_id: When set, response/error IDs are rewritten to this
            value, mapping replayed responses back to the original request.
        resumption_callback: Invoked with the SSE event ID so the caller can
            persist it as a resumption token.

    Returns:
        True when the event carried a JSON-RPC response or error (request
        complete); False to keep listening.
    """
    if sse.event == "message":
        try:
            message = JSONRPCMessage.model_validate_json(sse.data)
            logger.debug(f"SSE message: {message}")
            # If this is a response and we have original_request_id, replace it
            if original_request_id is not None and isinstance(
                message.root, JSONRPCResponse | JSONRPCError
            ):
                message.root.id = original_request_id
            session_message = SessionMessage(message)
            await read_stream_writer.send(session_message)
            # Call resumption token callback if we have an ID
            if sse.id and resumption_callback:
                await resumption_callback(sse.id)
            # If this is a response or error return True indicating completion
            # Otherwise, return False to continue listening
            return isinstance(message.root, JSONRPCResponse | JSONRPCError)
        except Exception as exc:
            # Parse failures are surfaced to the session as the exception
            # itself rather than tearing down the stream.
            logger.error(f"Error parsing SSE message: {exc}")
            await read_stream_writer.send(exc)
            return False
    else:
        # Non-"message" events are ignored (logged) rather than treated as errors.
        logger.warning(f"Unknown SSE event: {sse.event}")
        return False
async def handle_get_stream(
    self,
    client: httpx.AsyncClient,
    read_stream_writer: StreamWriter,
) -> None:
    """Open the standalone GET SSE stream for server-initiated messages.

    No-op until a session ID has been assigned. This stream is optional,
    so all failures are swallowed and only logged at debug level.

    Args:
        client: Shared HTTP client to open the stream on.
        read_stream_writer: Stream on which incoming messages are delivered.
    """
    try:
        if not self.session_id:
            return
        headers = self._update_headers_with_session(self.request_headers)
        async with aconnect_sse(
            client,
            "GET",
            self.url,
            headers=headers,
            # Use total_seconds(): timedelta.seconds is only the seconds
            # *component* (wrong for timeouts >= 1 day or fractional values).
            timeout=httpx.Timeout(
                self.timeout.total_seconds(),
                read=self.sse_read_timeout.total_seconds(),
            ),
        ) as event_source:
            event_source.response.raise_for_status()
            logger.debug("GET SSE connection established")
            async for sse in event_source.aiter_sse():
                # No completion semantics here; the stream lives until the
                # server closes it, so the return value is ignored.
                await self._handle_sse_event(sse, read_stream_writer)
    except Exception as exc:
        logger.debug(f"GET stream error (non-fatal): {exc}")
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
    """Handle a resumption request using GET with SSE.

    Reconnects to the server's event stream via ``Last-Event-ID`` and
    replays events until a response/error for the original request arrives.

    Args:
        ctx: Request context carrying the message, metadata, and streams.

    Raises:
        ResumptionError: If no resumption token is present in the metadata.
    """
    headers = self._update_headers_with_session(ctx.headers)
    if ctx.metadata and ctx.metadata.resumption_token:
        headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
    else:
        raise ResumptionError("Resumption request requires a resumption token")
    # Extract original request ID to map responses
    original_request_id = None
    if isinstance(ctx.session_message.message.root, JSONRPCRequest):
        original_request_id = ctx.session_message.message.root.id
    async with aconnect_sse(
        ctx.client,
        "GET",
        self.url,
        headers=headers,
        # Use total_seconds(): timedelta.seconds is only the seconds
        # *component* (wrong for timeouts >= 1 day or fractional values).
        timeout=httpx.Timeout(
            self.timeout.total_seconds(),
            read=ctx.sse_read_timeout.total_seconds(),
        ),
    ) as event_source:
        event_source.response.raise_for_status()
        logger.debug("Resumption GET SSE connection established")
        async for sse in event_source.aiter_sse():
            is_complete = await self._handle_sse_event(
                sse,
                ctx.read_stream_writer,
                original_request_id,
                ctx.metadata.on_resumption_token_update if ctx.metadata else None,
            )
            # Stop once the response/error for the resumed request arrives.
            if is_complete:
                break
async def _handle_post_request(self, ctx: RequestContext) -> None:
    """POST the message to the server and dispatch on the response.

    Handles the StreamableHTTP status conventions: 202 (accepted, no body),
    404 (session expired/invalid), then routes JSON vs SSE response bodies
    to the appropriate handler.

    Args:
        ctx: Request context carrying the message, metadata, and streams.
    """
    headers = self._update_headers_with_session(ctx.headers)
    message = ctx.session_message.message
    is_initialization = self._is_initialization_request(message)
    async with ctx.client.stream(
        "POST",
        self.url,
        json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
        headers=headers,
    ) as response:
        # 202 Accepted: nothing further to read (e.g. notifications).
        if response.status_code == 202:
            logger.debug("Received 202 Accepted")
            return
        # 404: the session is gone; fail the request (if it was one) with a
        # "Session terminated" error instead of raising.
        if response.status_code == 404:
            if isinstance(message.root, JSONRPCRequest):
                await self._send_session_terminated_error(
                    ctx.read_stream_writer,
                    message.root.id,
                )
            return
        response.raise_for_status()
        # The server assigns the session ID on the initialize response.
        if is_initialization:
            self._maybe_extract_session_id_from_response(response)
        content_type = response.headers.get(CONTENT_TYPE, "").lower()
        if content_type.startswith(JSON):
            await self._handle_json_response(response, ctx.read_stream_writer)
        elif content_type.startswith(SSE):
            await self._handle_sse_response(response, ctx)
        else:
            await self._handle_unexpected_content_type(
                content_type,
                ctx.read_stream_writer,
            )
async def _handle_json_response(
    self,
    response: httpx.Response,
    read_stream_writer: StreamWriter,
) -> None:
    """Parse a JSON response body and forward it to the read stream."""
    try:
        body = await response.aread()
        parsed = JSONRPCMessage.model_validate_json(body)
        await read_stream_writer.send(SessionMessage(parsed))
    except Exception as exc:
        # Deliver the parse error to the session instead of raising.
        logger.error(f"Error parsing JSON response: {exc}")
        await read_stream_writer.send(exc)
async def _handle_sse_response(
    self, response: httpx.Response, ctx: RequestContext
) -> None:
    """Consume an SSE response body, forwarding each event to the session.

    Unlike the resumption path, this does not break on completion; it reads
    until the server closes the stream.
    """
    try:
        event_source = EventSource(response)
        async for sse in event_source.aiter_sse():
            await self._handle_sse_event(
                sse,
                ctx.read_stream_writer,
                # Propagate event IDs as resumption tokens when requested.
                resumption_callback=(
                    ctx.metadata.on_resumption_token_update
                    if ctx.metadata
                    else None
                ),
            )
    except Exception as e:
        logger.exception("Error reading SSE stream:")
        await ctx.read_stream_writer.send(e)
async def _handle_unexpected_content_type(
    self,
    content_type: str,
    read_stream_writer: StreamWriter,
) -> None:
    """Report an unsupported response content type as an error on the read stream."""
    error = ValueError(f"Unexpected content type: {content_type}")
    logger.error(str(error))
    await read_stream_writer.send(error)
async def _send_session_terminated_error(
    self,
    read_stream_writer: StreamWriter,
    request_id: RequestId,
) -> None:
    """Fail *request_id* with a "Session terminated" JSON-RPC error.

    Used when the server answers 404, meaning the session has expired or
    been invalidated.
    """
    jsonrpc_error = JSONRPCError(
        jsonrpc="2.0",
        id=request_id,
        # JSON-RPC 2.0 reserves negative codes; "Invalid Request" is -32600.
        # The previous positive 32600 was not a valid spec error code.
        error=ErrorData(code=-32600, message="Session terminated"),
    )
    session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
    await read_stream_writer.send(session_message)
async def post_writer(
    self,
    client: httpx.AsyncClient,
    write_stream_reader: StreamReader,
    read_stream_writer: StreamWriter,
    write_stream: MemoryObjectSendStream[SessionMessage],
    start_get_stream: Callable[[], None],
) -> None:
    """Drain the write stream, sending each client message to the server.

    Args:
        client: Shared HTTP client.
        write_stream_reader: Source of outgoing session messages.
        read_stream_writer: Sink for responses delivered back to the session.
        write_stream: Paired send stream; closed on exit so readers unblock.
        start_get_stream: Callback that launches the optional GET SSE stream
            (invoked once the client sends ``notifications/initialized``).
    """
    try:
        async with write_stream_reader:
            async for session_message in write_stream_reader:
                message = session_message.message
                # Only client-side metadata is meaningful here.
                metadata = (
                    session_message.metadata
                    if isinstance(session_message.metadata, ClientMessageMetadata)
                    else None
                )
                # Check if this is a resumption request
                is_resumption = bool(metadata and metadata.resumption_token)
                logger.debug(f"Sending client message: {message}")
                # Handle initialized notification
                if self._is_initialized_notification(message):
                    start_get_stream()
                ctx = RequestContext(
                    client=client,
                    headers=self.request_headers,
                    session_id=self.session_id,
                    session_message=session_message,
                    metadata=metadata,
                    read_stream_writer=read_stream_writer,
                    sse_read_timeout=self.sse_read_timeout,
                )
                if is_resumption:
                    await self._handle_resumption_request(ctx)
                else:
                    await self._handle_post_request(ctx)
    except Exception as exc:
        logger.error(f"Error in post_writer: {exc}")
    finally:
        # Close both streams so the session sees EOF rather than hanging.
        await read_stream_writer.aclose()
        await write_stream.aclose()
async def terminate_session(self, client: httpx.AsyncClient) -> None:
    """Terminate the session by sending a DELETE request.

    No-op when no session ID has been assigned. Best-effort: failures are
    logged, never raised.
    """
    if not self.session_id:
        return
    try:
        headers = self._update_headers_with_session(self.request_headers)
        response = await client.delete(self.url, headers=headers)
        if response.status_code == 405:
            # 405 means the server forbids client-initiated termination.
            logger.debug("Server does not allow session termination")
        elif response.status_code != 200:
            logger.warning(f"Session termination failed: {response.status_code}")
    except Exception as exc:
        logger.warning(f"Session termination failed: {exc}")
def get_session_id(self) -> str | None:
    """Get the current session ID, or None if the server has not assigned one."""
    return self.session_id
@asynccontextmanager
@@ -41,7 +412,15 @@ async def streamablehttp_client(
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
):
terminate_on_close: bool = True,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback,
],
None,
]:
"""
Client transport for StreamableHTTP.
@@ -49,8 +428,12 @@ async def streamablehttp_client(
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Yields:
Tuple of (read_stream, write_stream, terminate_callback)
Tuple containing:
- read_stream: Stream for reading messages from the server
- write_stream: Stream for sending messages to the server
- get_session_id_callback: Function to retrieve the current session ID
"""
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
read_stream_writer, read_stream = anyio.create_memory_object_stream[
SessionMessage | Exception
@@ -59,208 +442,41 @@ async def streamablehttp_client(
SessionMessage
](0)
async def get_stream():
"""
Optional GET stream for server-initiated messages
"""
nonlocal session_id
try:
# Only attempt GET if we have a session ID
if not session_id:
return
get_headers = request_headers.copy()
get_headers[MCP_SESSION_ID_HEADER] = session_id
async with aconnect_sse(
client,
"GET",
url,
headers=get_headers,
timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds),
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
async for sse in event_source.aiter_sse():
if sse.event == "message":
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"GET message: {message}")
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc:
logger.error(f"Error parsing GET message: {exc}")
await read_stream_writer.send(exc)
else:
logger.warning(f"Unknown SSE event from GET: {sse.event}")
except Exception as exc:
# GET stream is optional, so don't propagate errors
logger.debug(f"GET stream error (non-fatal): {exc}")
async def post_writer(client: httpx.AsyncClient):
nonlocal session_id
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
message = session_message.message
# Add session ID to headers if we have one
post_headers = request_headers.copy()
if session_id:
post_headers[MCP_SESSION_ID_HEADER] = session_id
logger.debug(f"Sending client message: {message}")
# Handle initial initialization request
is_initialization = (
isinstance(message.root, JSONRPCRequest)
and message.root.method == "initialize"
)
if (
isinstance(message.root, JSONRPCNotification)
and message.root.method == "notifications/initialized"
):
tg.start_soon(get_stream)
async with client.stream(
"POST",
url,
json=message.model_dump(
by_alias=True, mode="json", exclude_none=True
),
headers=post_headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
continue
# Check for 404 (session expired/invalid)
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=message.root.id,
error=ErrorData(
code=32600,
message="Session terminated",
),
)
session_message = SessionMessage(
JSONRPCMessage(jsonrpc_error)
)
await read_stream_writer.send(session_message)
continue
response.raise_for_status()
# Extract session ID from response headers
if is_initialization:
new_session_id = response.headers.get(MCP_SESSION_ID_HEADER)
if new_session_id:
session_id = new_session_id
logger.info(f"Received session ID: {session_id}")
# Handle different response types
content_type = response.headers.get("content-type", "").lower()
if content_type.startswith(CONTENT_TYPE_JSON):
try:
content = await response.aread()
json_message = JSONRPCMessage.model_validate_json(
content
)
session_message = SessionMessage(json_message)
await read_stream_writer.send(session_message)
except Exception as exc:
logger.error(f"Error parsing JSON response: {exc}")
await read_stream_writer.send(exc)
elif content_type.startswith(CONTENT_TYPE_SSE):
# Parse SSE events from the response
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse():
if sse.event == "message":
try:
message = (
JSONRPCMessage.model_validate_json(
sse.data
)
)
session_message = SessionMessage(message)
await read_stream_writer.send(
session_message
)
except Exception as exc:
logger.exception("Error parsing message")
await read_stream_writer.send(exc)
else:
logger.warning(f"Unknown event: {sse.event}")
except Exception as e:
logger.exception("Error reading SSE stream:")
await read_stream_writer.send(e)
else:
# For 202 Accepted with no body
if response.status_code == 202:
logger.debug("Received 202 Accepted")
continue
error_msg = f"Unexpected content type: {content_type}"
logger.error(error_msg)
await read_stream_writer.send(ValueError(error_msg))
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
async def terminate_session():
"""
Terminate the session by sending a DELETE request.
"""
nonlocal session_id
if not session_id:
return # No session to terminate
try:
delete_headers = request_headers.copy()
delete_headers[MCP_SESSION_ID_HEADER] = session_id
response = await client.delete(
url,
headers=delete_headers,
)
if response.status_code == 405:
# Server doesn't allow client-initiated termination
logger.debug("Server does not allow session termination")
elif response.status_code != 200:
logger.warning(f"Session termination failed: {response.status_code}")
except Exception as exc:
logger.warning(f"Session termination failed: {exc}")
async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
# Set up headers with required Accept header
request_headers = {
"Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}",
"Content-Type": CONTENT_TYPE_JSON,
**(headers or {}),
}
# Track session ID if provided by server
session_id: str | None = None
async with httpx.AsyncClient(
headers=request_headers,
timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds),
headers=transport.request_headers,
timeout=httpx.Timeout(
transport.timeout.seconds, read=transport.sse_read_timeout.seconds
),
follow_redirects=True,
) as client:
tg.start_soon(post_writer, client)
# Define callbacks that need access to tg
def start_get_stream() -> None:
tg.start_soon(
transport.handle_get_stream, client, read_stream_writer
)
tg.start_soon(
transport.post_writer,
client,
write_stream_reader,
read_stream_writer,
write_stream,
start_get_stream,
)
try:
yield read_stream, write_stream, terminate_session
yield (
read_stream,
write_stream,
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()

View File

@@ -5,16 +5,24 @@ This module defines a wrapper type that combines JSONRPCMessage with metadata
to support transport-specific features like resumability.
"""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from mcp.types import JSONRPCMessage, RequestId
ResumptionToken = str
ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]]
@dataclass
class ClientMessageMetadata:
"""Metadata specific to client messages."""
resumption_token: str | None = None
resumption_token: ResumptionToken | None = None
on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = (
None
)
@dataclass

View File

@@ -12,7 +12,7 @@ from pydantic import BaseModel
from typing_extensions import Self
from mcp.shared.exceptions import McpError
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.types import (
CancelledNotification,
ClientNotification,
@@ -213,6 +213,7 @@ class BaseSession(
request: SendRequestT,
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
@@ -241,7 +242,9 @@ class BaseSession(
# TODO: Support progress callbacks
await self._write_stream.send(
SessionMessage(message=JSONRPCMessage(jsonrpc_request))
SessionMessage(
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
)
)
# request read timeout takes precedence over session read timeout

View File

@@ -23,16 +23,31 @@ from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.server import Server
from mcp.server.streamable_http import (
MCP_SESSION_ID_HEADER,
SESSION_ID_PATTERN,
EventCallback,
EventId,
EventMessage,
EventStore,
StreamableHTTPServerTransport,
StreamId,
)
from mcp.shared.exceptions import McpError
from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool
from mcp.shared.message import (
ClientMessageMetadata,
)
from mcp.shared.session import RequestResponder
from mcp.types import (
InitializeResult,
TextContent,
TextResourceContents,
Tool,
)
# Test constants
SERVER_NAME = "test_streamable_http_server"
@@ -49,6 +64,51 @@ INIT_REQUEST = {
}
# Simple in-memory event store for testing
class SimpleEventStore(EventStore):
    """Simple in-memory event store for testing."""

    def __init__(self):
        # Ordered log of (stream_id, event_id, message) tuples.
        self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
        self._event_id_counter = 0

    async def store_event(
        self, stream_id: StreamId, message: types.JSONRPCMessage
    ) -> EventId:
        """Record *message* on *stream_id* and return its new event ID."""
        self._event_id_counter += 1
        new_id = str(self._event_id_counter)
        self._events.append((stream_id, new_id, message))
        return new_id

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> StreamId | None:
        """Replay every stored event after *last_event_id* via *send_callback*.

        Returns the stream ID of the first replayed event, or None when
        there is nothing to replay. Unknown IDs replay from the beginning.
        """
        start_index = 0
        for index, (_, event_id, _) in enumerate(self._events):
            if event_id == last_event_id:
                start_index = index + 1
                break
        remaining = self._events[start_index:]
        stream_id = remaining[0][0] if remaining else None
        for _, event_id, message in remaining:
            await send_callback(EventMessage(message, event_id))
        return stream_id
# Test server implementation that follows MCP protocol
class ServerTest(Server):
def __init__(self):
@@ -78,25 +138,57 @@ class ServerTest(Server):
description="A test tool that sends a notification",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="long_running_with_checkpoints",
description="A long-running tool that sends periodic notifications",
inputSchema={"type": "object", "properties": {}},
),
]
@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
ctx = self.request_context
# When the tool is called, send a notification to test GET stream
if name == "test_tool_with_standalone_notification":
ctx = self.request_context
await ctx.session.send_resource_updated(
uri=AnyUrl("http://test_resource")
)
return [TextContent(type="text", text=f"Called {name}")]
elif name == "long_running_with_checkpoints":
# Send notifications that are part of the response stream
# This simulates a long-running tool that sends logs
await ctx.session.send_log_message(
level="info",
data="Tool started",
logger="tool",
related_request_id=ctx.request_id, # need for stream association
)
await anyio.sleep(0.1)
await ctx.session.send_log_message(
level="info",
data="Tool is almost done",
logger="tool",
related_request_id=ctx.request_id,
)
return [TextContent(type="text", text="Completed!")]
return [TextContent(type="text", text=f"Called {name}")]
def create_app(is_json_response_enabled=False) -> Starlette:
def create_app(
is_json_response_enabled=False, event_store: EventStore | None = None
) -> Starlette:
"""Create a Starlette application for testing that matches the example server.
Args:
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
event_store: Optional event store for testing resumability.
"""
# Create server instance
server = ServerTest()
@@ -139,6 +231,7 @@ def create_app(is_json_response_enabled=False) -> Starlette:
http_transport = StreamableHTTPServerTransport(
mcp_session_id=new_session_id,
is_json_response_enabled=is_json_response_enabled,
event_store=event_store,
)
async with http_transport.connect() as streams:
@@ -183,15 +276,18 @@ def create_app(is_json_response_enabled=False) -> Starlette:
return app
def run_server(port: int, is_json_response_enabled=False) -> None:
def run_server(
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
) -> None:
"""Run the test server.
Args:
port: Port to listen on.
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
event_store: Optional event store for testing resumability.
"""
app = create_app(is_json_response_enabled)
app = create_app(is_json_response_enabled, event_store)
# Configure server
config = uvicorn.Config(
app=app,
@@ -261,6 +357,53 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]:
proc.join(timeout=2)
@pytest.fixture
def event_store() -> SimpleEventStore:
    """Provide a fresh in-memory event store for each test."""
    return SimpleEventStore()
@pytest.fixture
def event_server_port() -> int:
    """Find an available port for the event store server."""
    # Bind to port 0 so the OS picks a free port, then release it.
    with socket.socket() as sock:
        sock.bind(("127.0.0.1", 0))
        _, port = sock.getsockname()
    return port
@pytest.fixture
def event_server(
    event_server_port: int, event_store: SimpleEventStore
) -> Generator[tuple[SimpleEventStore, str], None, None]:
    """Start a server subprocess with the event store enabled.

    Yields the (event_store, base_url) pair; the process is killed on
    teardown.
    """
    proc = multiprocessing.Process(
        target=run_server,
        kwargs={"port": event_server_port, "event_store": event_store},
        daemon=True,
    )
    proc.start()
    # Wait for server to be running
    max_attempts = 20
    attempt = 0
    while attempt < max_attempts:
        try:
            # Probe the port; success means uvicorn is accepting connections.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(("127.0.0.1", event_server_port))
            break
        except ConnectionRefusedError:
            time.sleep(0.1)
            attempt += 1
    else:
        # while/else: reached only when the loop exhausted without break.
        raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
    yield event_store, f"http://127.0.0.1:{event_server_port}"
    # Clean up
    proc.kill()
    proc.join(timeout=2)
@pytest.fixture
def json_response_server(json_server_port: int) -> Generator[None, None, None]:
"""Start a server with JSON response enabled."""
@@ -679,7 +822,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
"""Test client tool invocation."""
# First list tools
tools = await initialized_client_session.list_tools()
assert len(tools.tools) == 2
assert len(tools.tools) == 3
assert tools.tools[0].name == "test_tool"
# Call the tool
@@ -720,7 +863,7 @@ async def test_streamablehttp_client_session_persistence(
# Make multiple requests to verify session persistence
tools = await session.list_tools()
assert len(tools.tools) == 2
assert len(tools.tools) == 3
# Read a resource
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -751,7 +894,7 @@ async def test_streamablehttp_client_json_response(
# Check tool listing
tools = await session.list_tools()
assert len(tools.tools) == 2
assert len(tools.tools) == 3
# Call a tool and verify JSON response handling
result = await session.call_tool("test_tool", {})
@@ -813,25 +956,169 @@ async def test_streamablehttp_client_session_termination(
):
"""Test client session termination functionality."""
captured_session_id = None
# Create the streamablehttp_client with a custom httpx client to capture headers
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
terminate_session,
get_session_id,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
captured_session_id = get_session_id()
assert captured_session_id is not None
# Make a request to confirm session is working
tools = await session.list_tools()
assert len(tools.tools) == 2
assert len(tools.tools) == 3
# After exiting ClientSession context, explicitly terminate the session
await terminate_session()
headers = {}
if captured_session_id:
headers[MCP_SESSION_ID_HEADER] = captured_session_id
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
# Attempt to make a request after termination
with pytest.raises(
McpError,
match="Session terminated",
):
await session.list_tools()
@pytest.mark.anyio
async def test_streamablehttp_client_resumption(event_server):
"""Test client session to resume a long running tool."""
_, server_url = event_server
# Variables to track the state
captured_session_id = None
captured_resumption_token = None
captured_notifications = []
tool_started = False
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, types.ServerNotification):
captured_notifications.append(message)
# Look for our special notification that indicates the tool is running
if isinstance(message.root, types.LoggingMessageNotification):
if message.root.params.data == "Tool started":
nonlocal tool_started
tool_started = True
async def on_resumption_token_update(token: str) -> None:
nonlocal captured_resumption_token
captured_resumption_token = token
# First, start the client session and begin the long-running tool
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
read_stream,
write_stream,
get_session_id,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
captured_session_id = get_session_id()
assert captured_session_id is not None
# Start a long-running tool in a task
async with anyio.create_task_group() as tg:
async def run_tool():
metadata = ClientMessageMetadata(
on_resumption_token_update=on_resumption_token_update,
)
await session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="long_running_with_checkpoints", arguments={}
),
)
),
types.CallToolResult,
metadata=metadata,
)
tg.start_soon(run_tool)
# Wait for the tool to start and at least one notification
# and then kill the task group
while not tool_started or not captured_resumption_token:
await anyio.sleep(0.1)
tg.cancel_scope.cancel()
# Store pre notifications and clear the captured notifications
# for the post-resumption check
captured_notifications_pre = captured_notifications.copy()
captured_notifications = []
# Now resume the session with the same mcp-session-id
headers = {}
if captured_session_id:
headers[MCP_SESSION_ID_HEADER] = captured_session_id
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
# Don't initialize - just use the existing session
# Resume the tool with the resumption token
assert captured_resumption_token is not None
metadata = ClientMessageMetadata(
resumption_token=captured_resumption_token,
)
result = await session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="long_running_with_checkpoints", arguments={}
),
)
),
types.CallToolResult,
metadata=metadata,
)
# We should get a complete result
assert len(result.content) == 1
assert result.content[0].type == "text"
assert "Completed" in result.content[0].text
# We should have received the remaining notifications
assert len(captured_notifications) > 0
# Should not have the first notification
# Check that "Tool started" notification isn't repeated when resuming
assert not any(
isinstance(n.root, types.LoggingMessageNotification)
and n.root.params.data == "Tool started"
for n in captured_notifications
)
# there is no intersection between pre and post notifications
assert not any(
n in captured_notifications_pre for n in captured_notifications
)