mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-23 08:44:22 +01:00
feat: implement MCP-Protocol-Version header requirement for HTTP transport (#898)
This commit is contained in:
@@ -16,6 +16,7 @@ from mcp.server.auth.handlers.token import TokenHandler
|
||||
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
|
||||
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
|
||||
from mcp.shared.auth import OAuthMetadata
|
||||
|
||||
|
||||
@@ -55,7 +56,7 @@ def cors_middleware(
|
||||
app=request_response(handler),
|
||||
allow_origins="*",
|
||||
allow_methods=allow_methods,
|
||||
allow_headers=["mcp-protocol-version"],
|
||||
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
|
||||
)
|
||||
return cors_app
|
||||
|
||||
|
||||
@@ -25,7 +25,9 @@ from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
from mcp.types import (
|
||||
DEFAULT_NEGOTIATED_VERSION,
|
||||
INTERNAL_ERROR,
|
||||
INVALID_PARAMS,
|
||||
INVALID_REQUEST,
|
||||
@@ -45,6 +47,7 @@ MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB
|
||||
|
||||
# Header names
|
||||
MCP_SESSION_ID_HEADER = "mcp-session-id"
|
||||
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
|
||||
LAST_EVENT_ID_HEADER = "last-event-id"
|
||||
|
||||
# Content types
|
||||
@@ -293,7 +296,7 @@ class StreamableHTTPServerTransport:
|
||||
has_json, has_sse = self._check_accept_headers(request)
|
||||
if not (has_json and has_sse):
|
||||
response = self._create_error_response(
|
||||
("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
|
||||
("Not Acceptable: Client must accept both application/json and text/event-stream"),
|
||||
HTTPStatus.NOT_ACCEPTABLE,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
@@ -353,8 +356,7 @@ class StreamableHTTPServerTransport:
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
# For non-initialization requests, validate the session
|
||||
elif not await self._validate_session(request, send):
|
||||
elif not await self._validate_request_headers(request, send):
|
||||
return
|
||||
|
||||
# For notifications and responses only, return 202 Accepted
|
||||
@@ -513,8 +515,9 @@ class StreamableHTTPServerTransport:
|
||||
await response(request.scope, request.receive, send)
|
||||
return
|
||||
|
||||
if not await self._validate_session(request, send):
|
||||
if not await self._validate_request_headers(request, send):
|
||||
return
|
||||
|
||||
# Handle resumability: check for Last-Event-ID header
|
||||
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
|
||||
await self._replay_events(last_event_id, request, send)
|
||||
@@ -593,7 +596,7 @@ class StreamableHTTPServerTransport:
|
||||
await response(request.scope, request.receive, send)
|
||||
return
|
||||
|
||||
if not await self._validate_session(request, send):
|
||||
if not await self._validate_request_headers(request, send):
|
||||
return
|
||||
|
||||
await self._terminate_session()
|
||||
@@ -653,6 +656,13 @@ class StreamableHTTPServerTransport:
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
|
||||
async def _validate_request_headers(self, request: Request, send: Send) -> bool:
|
||||
if not await self._validate_session(request, send):
|
||||
return False
|
||||
if not await self._validate_protocol_version(request, send):
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _validate_session(self, request: Request, send: Send) -> bool:
|
||||
"""Validate the session ID in the request."""
|
||||
if not self.mcp_session_id:
|
||||
@@ -682,6 +692,28 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
return True
|
||||
|
||||
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
|
||||
"""Validate the protocol version header in the request."""
|
||||
# Get the protocol version from the request headers
|
||||
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
|
||||
|
||||
# If no protocol version provided, assume default version
|
||||
if protocol_version is None:
|
||||
protocol_version = DEFAULT_NEGOTIATED_VERSION
|
||||
|
||||
# Check if the protocol version is supported
|
||||
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
|
||||
response = self._create_error_response(
|
||||
f"Bad Request: Unsupported protocol version: {protocol_version}. "
|
||||
+ f"Supported versions: {supported_versions}",
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
|
||||
"""
|
||||
Replays events that would have been sent after the specified event ID.
|
||||
|
||||
Reference in New Issue
Block a user