feat: implement MCP-Protocol-Version header requirement for HTTP transport (#898)

This commit is contained in:
Felix Weinberger
2025-06-12 18:01:53 +01:00
committed by GitHub
parent 0bcecffc4c
commit df15e9566d
6 changed files with 268 additions and 21 deletions

View File

@@ -17,7 +17,13 @@ from urllib.parse import urlencode, urljoin
import anyio
import httpx
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
)
from mcp.types import LATEST_PROTOCOL_VERSION
logger = logging.getLogger(__name__)
@@ -121,7 +127,7 @@ class OAuthClientProvider(httpx.Auth):
# Extract base URL per MCP spec
auth_base_url = self._get_authorization_base_url(server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
async with httpx.AsyncClient() as client:
try:

View File

@@ -22,6 +22,7 @@ from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
@@ -39,6 +40,7 @@ StreamReader = MemoryObjectReceiveStream[SessionMessage]
GetSessionIdCallback = Callable[[], str | None]
MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "Accept"
@@ -97,17 +99,20 @@ class StreamableHTTPTransport:
)
self.auth = auth
self.session_id = None
self.protocol_version = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
**self.headers,
}
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID and protocol version if available."""
headers = base_headers.copy()
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
if self.protocol_version:
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
return headers
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
@@ -128,12 +133,28 @@ class StreamableHTTPTransport:
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")
def _maybe_extract_protocol_version_from_message(
self,
message: JSONRPCMessage,
) -> None:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result:
try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result)
self.protocol_version = str(init_result.protocolVersion)
logger.info(f"Negotiated protocol version: {self.protocol_version}")
except Exception as exc:
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
logger.warning(f"Raw result: {message.root.result}")
async def _handle_sse_event(
self,
sse: ServerSentEvent,
read_stream_writer: StreamWriter,
original_request_id: RequestId | None = None,
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
is_initialization: bool = False,
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
@@ -141,6 +162,10 @@ class StreamableHTTPTransport:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"SSE message: {message}")
# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root.id = original_request_id
@@ -174,7 +199,7 @@ class StreamableHTTPTransport:
if not self.session_id:
return
headers = self._update_headers_with_session(self.request_headers)
headers = self._prepare_request_headers(self.request_headers)
async with aconnect_sse(
client,
@@ -194,7 +219,7 @@ class StreamableHTTPTransport:
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers)
headers = self._prepare_request_headers(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
@@ -227,7 +252,7 @@ class StreamableHTTPTransport:
async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
headers = self._prepare_request_headers(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
@@ -256,9 +281,9 @@ class StreamableHTTPTransport:
content_type = response.headers.get(CONTENT_TYPE, "").lower()
if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer)
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx)
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type(
content_type,
@@ -269,18 +294,29 @@ class StreamableHTTPTransport:
self,
response: httpx.Response,
read_stream_writer: StreamWriter,
is_initialization: bool = False,
) -> None:
"""Handle JSON response from the server."""
try:
content = await response.aread()
message = JSONRPCMessage.model_validate_json(content)
# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc:
logger.error(f"Error parsing JSON response: {exc}")
await read_stream_writer.send(exc)
async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
async def _handle_sse_response(
self,
response: httpx.Response,
ctx: RequestContext,
is_initialization: bool = False,
) -> None:
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
@@ -289,6 +325,7 @@ class StreamableHTTPTransport:
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
@@ -385,7 +422,7 @@ class StreamableHTTPTransport:
return
try:
headers = self._update_headers_with_session(self.request_headers)
headers = self._prepare_request_headers(self.request_headers)
response = await client.delete(self.url, headers=headers)
if response.status_code == 405:

View File

@@ -16,6 +16,7 @@ from mcp.server.auth.handlers.token import TokenHandler
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
from mcp.shared.auth import OAuthMetadata
@@ -55,7 +56,7 @@ def cors_middleware(
app=request_response(handler),
allow_origins="*",
allow_methods=allow_methods,
allow_headers=["mcp-protocol-version"],
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
)
return cors_app

View File

@@ -25,7 +25,9 @@ from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
DEFAULT_NEGOTIATED_VERSION,
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
@@ -45,6 +47,7 @@ MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB
# Header names
MCP_SESSION_ID_HEADER = "mcp-session-id"
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
LAST_EVENT_ID_HEADER = "last-event-id"
# Content types
@@ -293,7 +296,7 @@ class StreamableHTTPServerTransport:
has_json, has_sse = self._check_accept_headers(request)
if not (has_json and has_sse):
response = self._create_error_response(
("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
("Not Acceptable: Client must accept both application/json and text/event-stream"),
HTTPStatus.NOT_ACCEPTABLE,
)
await response(scope, receive, send)
@@ -353,8 +356,7 @@ class StreamableHTTPServerTransport:
)
await response(scope, receive, send)
return
# For non-initialization requests, validate the session
elif not await self._validate_session(request, send):
elif not await self._validate_request_headers(request, send):
return
# For notifications and responses only, return 202 Accepted
@@ -513,8 +515,9 @@ class StreamableHTTPServerTransport:
await response(request.scope, request.receive, send)
return
if not await self._validate_session(request, send):
if not await self._validate_request_headers(request, send):
return
# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
await self._replay_events(last_event_id, request, send)
@@ -593,7 +596,7 @@ class StreamableHTTPServerTransport:
await response(request.scope, request.receive, send)
return
if not await self._validate_session(request, send):
if not await self._validate_request_headers(request, send):
return
await self._terminate_session()
@@ -653,6 +656,13 @@ class StreamableHTTPServerTransport:
)
await response(request.scope, request.receive, send)
async def _validate_request_headers(self, request: Request, send: Send) -> bool:
if not await self._validate_session(request, send):
return False
if not await self._validate_protocol_version(request, send):
return False
return True
async def _validate_session(self, request: Request, send: Send) -> bool:
"""Validate the session ID in the request."""
if not self.mcp_session_id:
@@ -682,6 +692,28 @@ class StreamableHTTPServerTransport:
return True
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
"""Validate the protocol version header in the request."""
# Get the protocol version from the request headers
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
# If no protocol version provided, assume default version
if protocol_version is None:
protocol_version = DEFAULT_NEGOTIATED_VERSION
# Check if the protocol version is supported
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
response = self._create_error_response(
f"Bad Request: Unsupported protocol version: {protocol_version}. "
+ f"Supported versions: {supported_versions}",
HTTPStatus.BAD_REQUEST,
)
await response(request.scope, request.receive, send)
return False
return True
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
"""
Replays events that would have been sent after the specified event ID.

View File

@@ -24,6 +24,14 @@ for reference.
LATEST_PROTOCOL_VERSION = "2025-03-26"
"""
The default negotiated version of the Model Context Protocol when no version is specified.
We need this to satisfy the MCP specification, which requires the server to assume a
specific version if none is provided by the client. See section "Protocol Version Header" at
https://modelcontextprotocol.io/specification
"""
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]