feat: implement MCP-Protocol-Version header requirement for HTTP transport (#898)

This commit is contained in:
Felix Weinberger
2025-06-12 18:01:53 +01:00
committed by GitHub
parent 0bcecffc4c
commit df15e9566d
6 changed files with 268 additions and 21 deletions

View File

@@ -26,6 +26,7 @@ from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.server import Server
from mcp.server.streamable_http import (
MCP_PROTOCOL_VERSION_HEADER,
MCP_SESSION_ID_HEADER,
SESSION_ID_PATTERN,
EventCallback,
@@ -64,6 +65,17 @@ INIT_REQUEST = {
}
# Helper functions
def extract_protocol_version_from_sse(response: requests.Response) -> str:
"""Extract the negotiated protocol version from an SSE initialization response."""
assert response.headers.get("Content-Type") == "text/event-stream"
for line in response.text.splitlines():
if line.startswith("data: "):
init_data = json.loads(line[6:])
return init_data["result"]["protocolVersion"]
raise ValueError("Could not extract protocol version from SSE response")
# Simple in-memory event store for testing
class SimpleEventStore(EventStore):
"""Simple in-memory event store for testing."""
@@ -560,11 +572,17 @@ def test_session_termination(basic_server, basic_server_url):
)
assert response.status_code == 200
# Extract negotiated protocol version from SSE response
negotiated_version = extract_protocol_version_from_sse(response)
# Now terminate the session
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
response = requests.delete(
f"{basic_server_url}/mcp",
headers={MCP_SESSION_ID_HEADER: session_id},
headers={
MCP_SESSION_ID_HEADER: session_id,
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
},
)
assert response.status_code == 200
@@ -595,16 +613,20 @@ def test_response(basic_server, basic_server_url):
)
assert response.status_code == 200
# Now terminate the session
# Extract negotiated protocol version from SSE response
negotiated_version = extract_protocol_version_from_sse(response)
# Now get the session ID
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
# Try to use the terminated session
# Try to use the session with proper headers
tools_response = requests.post(
mcp_url,
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
},
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"},
stream=True,
@@ -646,12 +668,23 @@ def test_get_sse_stream(basic_server, basic_server_url):
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
assert session_id is not None
# Extract negotiated protocol version from SSE response
init_data = None
assert init_response.headers.get("Content-Type") == "text/event-stream"
for line in init_response.text.splitlines():
if line.startswith("data: "):
init_data = json.loads(line[6:])
break
assert init_data is not None
negotiated_version = init_data["result"]["protocolVersion"]
# Now attempt to establish an SSE stream via GET
get_response = requests.get(
mcp_url,
headers={
"Accept": "text/event-stream",
MCP_SESSION_ID_HEADER: session_id,
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
},
stream=True,
)
@@ -666,6 +699,7 @@ def test_get_sse_stream(basic_server, basic_server_url):
headers={
"Accept": "text/event-stream",
MCP_SESSION_ID_HEADER: session_id,
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
},
stream=True,
)
@@ -694,11 +728,22 @@ def test_get_validation(basic_server, basic_server_url):
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
assert session_id is not None
# Extract negotiated protocol version from SSE response
init_data = None
assert init_response.headers.get("Content-Type") == "text/event-stream"
for line in init_response.text.splitlines():
if line.startswith("data: "):
init_data = json.loads(line[6:])
break
assert init_data is not None
negotiated_version = init_data["result"]["protocolVersion"]
# Test without Accept header
response = requests.get(
mcp_url,
headers={
MCP_SESSION_ID_HEADER: session_id,
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
},
stream=True,
)
@@ -711,6 +756,7 @@ def test_get_validation(basic_server, basic_server_url):
headers={
"Accept": "application/json",
MCP_SESSION_ID_HEADER: session_id,
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
},
)
assert response.status_code == 406
@@ -1004,6 +1050,7 @@ async def test_streamablehttp_client_resumption(event_server):
captured_resumption_token = None
captured_notifications = []
tool_started = False
captured_protocol_version = None
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
@@ -1032,6 +1079,8 @@ async def test_streamablehttp_client_resumption(event_server):
assert isinstance(result, InitializeResult)
captured_session_id = get_session_id()
assert captured_session_id is not None
# Capture the negotiated protocol version
captured_protocol_version = result.protocolVersion
# Start a long-running tool in a task
async with anyio.create_task_group() as tg:
@@ -1064,10 +1113,12 @@ async def test_streamablehttp_client_resumption(event_server):
captured_notifications_pre = captured_notifications.copy()
captured_notifications = []
# Now resume the session with the same mcp-session-id
# Now resume the session with the same mcp-session-id and protocol version
headers = {}
if captured_session_id:
headers[MCP_SESSION_ID_HEADER] = captured_session_id
if captured_protocol_version:
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
read_stream,
@@ -1358,3 +1409,115 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No
assert ctx["headers"].get("x-request-id") == f"request-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
@pytest.mark.anyio
async def test_client_includes_protocol_version_header_after_init(context_aware_server, basic_server_url):
"""Test that client includes mcp-protocol-version header after initialization."""
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize and get the negotiated version
init_result = await session.initialize()
negotiated_version = init_result.protocolVersion
# Call a tool that echoes headers to verify the header is present
tool_result = await session.call_tool("echo_headers", {})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
headers_data = json.loads(tool_result.content[0].text)
# Verify protocol version header is present
assert "mcp-protocol-version" in headers_data
assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version
def test_server_validates_protocol_version_header(basic_server, basic_server_url):
"""Test that server returns 400 Bad Request version if header unsupported or invalid."""
# First initialize a session to get a valid session ID
init_response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert init_response.status_code == 200
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
# Test request with invalid protocol version (should fail)
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_SESSION_ID_HEADER: session_id,
MCP_PROTOCOL_VERSION_HEADER: "invalid-version",
},
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"},
)
assert response.status_code == 400
assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower()
# Test request with unsupported protocol version (should fail)
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_SESSION_ID_HEADER: session_id,
MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version
},
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"},
)
assert response.status_code == 400
assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower()
# Test request with valid protocol version (should succeed)
negotiated_version = extract_protocol_version_from_sse(init_response)
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_SESSION_ID_HEADER: session_id,
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
},
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"},
)
assert response.status_code == 200
def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_server_url):
"""Test server accepts requests without protocol version header."""
# First initialize a session to get a valid session ID
init_response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert init_response.status_code == 200
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
# Test request without mcp-protocol-version header (backwards compatibility)
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_SESSION_ID_HEADER: session_id,
},
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"},
stream=True,
)
assert response.status_code == 200 # Should succeed for backwards compatibility
assert response.headers.get("Content-Type") == "text/event-stream"