mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
feat: implement MCP-Protocol-Version header requirement for HTTP transport (#898)
This commit is contained in:
@@ -26,6 +26,7 @@ from mcp.client.session import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.server import Server
|
||||
from mcp.server.streamable_http import (
|
||||
MCP_PROTOCOL_VERSION_HEADER,
|
||||
MCP_SESSION_ID_HEADER,
|
||||
SESSION_ID_PATTERN,
|
||||
EventCallback,
|
||||
@@ -64,6 +65,17 @@ INIT_REQUEST = {
|
||||
}
|
||||
|
||||
|
||||
# Helper functions
|
||||
def extract_protocol_version_from_sse(response: requests.Response) -> str:
|
||||
"""Extract the negotiated protocol version from an SSE initialization response."""
|
||||
assert response.headers.get("Content-Type") == "text/event-stream"
|
||||
for line in response.text.splitlines():
|
||||
if line.startswith("data: "):
|
||||
init_data = json.loads(line[6:])
|
||||
return init_data["result"]["protocolVersion"]
|
||||
raise ValueError("Could not extract protocol version from SSE response")
|
||||
|
||||
|
||||
# Simple in-memory event store for testing
|
||||
class SimpleEventStore(EventStore):
|
||||
"""Simple in-memory event store for testing."""
|
||||
@@ -560,11 +572,17 @@ def test_session_termination(basic_server, basic_server_url):
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Extract negotiated protocol version from SSE response
|
||||
negotiated_version = extract_protocol_version_from_sse(response)
|
||||
|
||||
# Now terminate the session
|
||||
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
response = requests.delete(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={MCP_SESSION_ID_HEADER: session_id},
|
||||
headers={
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -595,16 +613,20 @@ def test_response(basic_server, basic_server_url):
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Now terminate the session
|
||||
# Extract negotiated protocol version from SSE response
|
||||
negotiated_version = extract_protocol_version_from_sse(response)
|
||||
|
||||
# Now get the session ID
|
||||
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
# Try to use the terminated session
|
||||
# Try to use the session with proper headers
|
||||
tools_response = requests.post(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier
|
||||
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"},
|
||||
stream=True,
|
||||
@@ -646,12 +668,23 @@ def test_get_sse_stream(basic_server, basic_server_url):
|
||||
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
assert session_id is not None
|
||||
|
||||
# Extract negotiated protocol version from SSE response
|
||||
init_data = None
|
||||
assert init_response.headers.get("Content-Type") == "text/event-stream"
|
||||
for line in init_response.text.splitlines():
|
||||
if line.startswith("data: "):
|
||||
init_data = json.loads(line[6:])
|
||||
break
|
||||
assert init_data is not None
|
||||
negotiated_version = init_data["result"]["protocolVersion"]
|
||||
|
||||
# Now attempt to establish an SSE stream via GET
|
||||
get_response = requests.get(
|
||||
mcp_url,
|
||||
headers={
|
||||
"Accept": "text/event-stream",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
@@ -666,6 +699,7 @@ def test_get_sse_stream(basic_server, basic_server_url):
|
||||
headers={
|
||||
"Accept": "text/event-stream",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
@@ -694,11 +728,22 @@ def test_get_validation(basic_server, basic_server_url):
|
||||
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
assert session_id is not None
|
||||
|
||||
# Extract negotiated protocol version from SSE response
|
||||
init_data = None
|
||||
assert init_response.headers.get("Content-Type") == "text/event-stream"
|
||||
for line in init_response.text.splitlines():
|
||||
if line.startswith("data: "):
|
||||
init_data = json.loads(line[6:])
|
||||
break
|
||||
assert init_data is not None
|
||||
negotiated_version = init_data["result"]["protocolVersion"]
|
||||
|
||||
# Test without Accept header
|
||||
response = requests.get(
|
||||
mcp_url,
|
||||
headers={
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
@@ -711,6 +756,7 @@ def test_get_validation(basic_server, basic_server_url):
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 406
|
||||
@@ -1004,6 +1050,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
captured_resumption_token = None
|
||||
captured_notifications = []
|
||||
tool_started = False
|
||||
captured_protocol_version = None
|
||||
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
@@ -1032,6 +1079,8 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
assert isinstance(result, InitializeResult)
|
||||
captured_session_id = get_session_id()
|
||||
assert captured_session_id is not None
|
||||
# Capture the negotiated protocol version
|
||||
captured_protocol_version = result.protocolVersion
|
||||
|
||||
# Start a long-running tool in a task
|
||||
async with anyio.create_task_group() as tg:
|
||||
@@ -1064,10 +1113,12 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
captured_notifications_pre = captured_notifications.copy()
|
||||
captured_notifications = []
|
||||
|
||||
# Now resume the session with the same mcp-session-id
|
||||
# Now resume the session with the same mcp-session-id and protocol version
|
||||
headers = {}
|
||||
if captured_session_id:
|
||||
headers[MCP_SESSION_ID_HEADER] = captured_session_id
|
||||
if captured_protocol_version:
|
||||
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
|
||||
|
||||
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
|
||||
read_stream,
|
||||
@@ -1358,3 +1409,115 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No
|
||||
assert ctx["headers"].get("x-request-id") == f"request-{i}"
|
||||
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
|
||||
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_includes_protocol_version_header_after_init(context_aware_server, basic_server_url):
|
||||
"""Test that client includes mcp-protocol-version header after initialization."""
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# Initialize and get the negotiated version
|
||||
init_result = await session.initialize()
|
||||
negotiated_version = init_result.protocolVersion
|
||||
|
||||
# Call a tool that echoes headers to verify the header is present
|
||||
tool_result = await session.call_tool("echo_headers", {})
|
||||
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
headers_data = json.loads(tool_result.content[0].text)
|
||||
|
||||
# Verify protocol version header is present
|
||||
assert "mcp-protocol-version" in headers_data
|
||||
assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version
|
||||
|
||||
|
||||
def test_server_validates_protocol_version_header(basic_server, basic_server_url):
|
||||
"""Test that server returns 400 Bad Request version if header unsupported or invalid."""
|
||||
# First initialize a session to get a valid session ID
|
||||
init_response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=INIT_REQUEST,
|
||||
)
|
||||
assert init_response.status_code == 200
|
||||
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
# Test request with invalid protocol version (should fail)
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
MCP_PROTOCOL_VERSION_HEADER: "invalid-version",
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower()
|
||||
|
||||
# Test request with unsupported protocol version (should fail)
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower()
|
||||
|
||||
# Test request with valid protocol version (should succeed)
|
||||
negotiated_version = extract_protocol_version_from_sse(init_response)
|
||||
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_server_url):
|
||||
"""Test server accepts requests without protocol version header."""
|
||||
# First initialize a session to get a valid session ID
|
||||
init_response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=INIT_REQUEST,
|
||||
)
|
||||
assert init_response.status_code == 200
|
||||
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
# Test request without mcp-protocol-version header (backwards compatibility)
|
||||
response = requests.post(
|
||||
f"{basic_server_url}/mcp",
|
||||
headers={
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
MCP_SESSION_ID_HEADER: session_id,
|
||||
},
|
||||
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"},
|
||||
stream=True,
|
||||
)
|
||||
assert response.status_code == 200 # Should succeed for backwards compatibility
|
||||
assert response.headers.get("Content-Type") == "text/event-stream"
|
||||
|
||||
Reference in New Issue
Block a user