diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 7782022..4e777d6 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -17,7 +17,13 @@ from urllib.parse import urlencode, urljoin import anyio import httpx -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, +) from mcp.types import LATEST_PROTOCOL_VERSION logger = logging.getLogger(__name__) @@ -121,7 +127,7 @@ class OAuthClientProvider(httpx.Auth): # Extract base URL per MCP spec auth_base_url = self._get_authorization_base_url(server_url) url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server") - headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION} + headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} async with httpx.AsyncClient() as client: try: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 4718705..39ac34d 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -22,6 +22,7 @@ from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, + InitializeResult, JSONRPCError, JSONRPCMessage, JSONRPCNotification, @@ -39,6 +40,7 @@ StreamReader = MemoryObjectReceiveStream[SessionMessage] GetSessionIdCallback = Callable[[], str | None] MCP_SESSION_ID = "mcp-session-id" +MCP_PROTOCOL_VERSION = "mcp-protocol-version" LAST_EVENT_ID = "last-event-id" CONTENT_TYPE = "content-type" ACCEPT = "Accept" @@ -97,17 +99,20 @@ class StreamableHTTPTransport: ) self.auth = auth self.session_id = None + self.protocol_version = None self.request_headers = { ACCEPT: f"{JSON}, {SSE}", CONTENT_TYPE: JSON, **self.headers, } - def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: - """Update headers with session ID if available.""" + def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]: + """Update headers with session ID and protocol version if available.""" headers = base_headers.copy() if self.session_id: headers[MCP_SESSION_ID] = self.session_id + if self.protocol_version: + headers[MCP_PROTOCOL_VERSION] = self.protocol_version return headers def _is_initialization_request(self, message: JSONRPCMessage) -> bool: @@ -128,12 +133,28 @@ class StreamableHTTPTransport: self.session_id = new_session_id logger.info(f"Received session ID: {self.session_id}") + def _maybe_extract_protocol_version_from_message( + self, + message: JSONRPCMessage, + ) -> None: + """Extract protocol version from initialization response message.""" + if isinstance(message.root, JSONRPCResponse) and message.root.result: + try: + # Parse the result as InitializeResult for type safety + init_result = InitializeResult.model_validate(message.root.result) + self.protocol_version = str(init_result.protocolVersion) + logger.info(f"Negotiated protocol version: {self.protocol_version}") + except Exception as exc: + logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}") + logger.warning(f"Raw result: {message.root.result}") + async def _handle_sse_event( self, sse: ServerSentEvent, read_stream_writer: StreamWriter, original_request_id: RequestId | None = None, resumption_callback: Callable[[str], Awaitable[None]] | None = None, + is_initialization: bool = False, ) -> bool: """Handle an SSE event, returning True if the response is complete.""" if sse.event == "message": @@ -141,6 +162,10 @@ class StreamableHTTPTransport: message = JSONRPCMessage.model_validate_json(sse.data) logger.debug(f"SSE message: {message}") + # Extract protocol version from initialization response + if is_initialization: + self._maybe_extract_protocol_version_from_message(message) + # If this is a response and we have original_request_id, replace it if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): message.root.id = original_request_id @@ -174,7 +199,7 @@ class StreamableHTTPTransport: if not self.session_id: return - headers = self._update_headers_with_session(self.request_headers) + headers = self._prepare_request_headers(self.request_headers) async with aconnect_sse( client, @@ -194,7 +219,7 @@ class StreamableHTTPTransport: async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" - headers = self._update_headers_with_session(ctx.headers) + headers = self._prepare_request_headers(ctx.headers) if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: @@ -227,7 +252,7 @@ class StreamableHTTPTransport: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._update_headers_with_session(ctx.headers) + headers = self._prepare_request_headers(ctx.headers) message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -256,9 +281,9 @@ class StreamableHTTPTransport: content_type = response.headers.get(CONTENT_TYPE, "").lower() if content_type.startswith(JSON): - await self._handle_json_response(response, ctx.read_stream_writer) + await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) elif content_type.startswith(SSE): - await self._handle_sse_response(response, ctx) + await self._handle_sse_response(response, ctx, is_initialization) else: await self._handle_unexpected_content_type( content_type, @@ -269,18 +294,29 @@ class StreamableHTTPTransport: self, response: httpx.Response, read_stream_writer: StreamWriter, + is_initialization: bool = False, ) -> None: """Handle JSON response from the server.""" try: content = await response.aread() message = JSONRPCMessage.model_validate_json(content) + + # Extract protocol version from initialization response + if is_initialization: + self._maybe_extract_protocol_version_from_message(message) + session_message = SessionMessage(message) await read_stream_writer.send(session_message) except Exception as exc: logger.error(f"Error parsing JSON response: {exc}") await read_stream_writer.send(exc) - async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: + async def _handle_sse_response( + self, + response: httpx.Response, + ctx: RequestContext, + is_initialization: bool = False, + ) -> None: """Handle SSE response from the server.""" try: event_source = EventSource(response) @@ -289,6 +325,7 @@ class StreamableHTTPTransport: sse, ctx.read_stream_writer, resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), + is_initialization=is_initialization, ) # If the SSE event indicates completion, like returning respose/error # break the loop @@ -385,7 +422,7 @@ class StreamableHTTPTransport: return try: - headers = self._update_headers_with_session(self.request_headers) + headers = self._prepare_request_headers(self.request_headers) response = await client.delete(self.url, headers=headers) if response.status_code == 405: diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index dff468e..8647334 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -16,6 +16,7 @@ from mcp.server.auth.handlers.token import TokenHandler from mcp.server.auth.middleware.client_auth import ClientAuthenticator from mcp.server.auth.provider import OAuthAuthorizationServerProvider from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions +from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.auth import OAuthMetadata @@ -55,7 +56,7 @@ def cors_middleware( app=request_response(handler), allow_origins="*", allow_methods=allow_methods, - allow_headers=["mcp-protocol-version"], + allow_headers=[MCP_PROTOCOL_VERSION_HEADER], ) return cors_app diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 9356a99..13ee27b 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -25,7 +25,9 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( + DEFAULT_NEGOTIATED_VERSION, INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, @@ -45,6 +47,7 @@ MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB # Header names MCP_SESSION_ID_HEADER = "mcp-session-id" +MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" LAST_EVENT_ID_HEADER = "last-event-id" # Content types @@ -293,7 +296,7 @@ class StreamableHTTPServerTransport: has_json, has_sse = self._check_accept_headers(request) if not (has_json and has_sse): response = self._create_error_response( - ("Not Acceptable: Client must accept both application/json and " "text/event-stream"), + ("Not Acceptable: Client must accept both application/json and text/event-stream"), HTTPStatus.NOT_ACCEPTABLE, ) await response(scope, receive, send) @@ -353,8 +356,7 @@ class StreamableHTTPServerTransport: ) await response(scope, receive, send) return - # For non-initialization requests, validate the session - elif not await self._validate_session(request, send): + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -513,8 +515,9 @@ class StreamableHTTPServerTransport: await response(request.scope, request.receive, send) return - if not await self._validate_session(request, send): + if not await self._validate_request_headers(request, send): return + # Handle resumability: check for Last-Event-ID header if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) @@ -593,7 +596,7 @@ class StreamableHTTPServerTransport: await response(request.scope, request.receive, send) return - if not await self._validate_session(request, send): + if not await self._validate_request_headers(request, send): return await self._terminate_session() @@ -653,6 +656,13 @@ class StreamableHTTPServerTransport: ) await response(request.scope, request.receive, send) + async def _validate_request_headers(self, request: Request, send: Send) -> bool: + if not await self._validate_session(request, send): + return False + if not await self._validate_protocol_version(request, send): + return False + return True + async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" if not self.mcp_session_id: @@ -682,6 +692,28 @@ class StreamableHTTPServerTransport: return True + async def _validate_protocol_version(self, request: Request, send: Send) -> bool: + """Validate the protocol version header in the request.""" + # Get the protocol version from the request headers + protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) + + # If no protocol version provided, assume default version + if protocol_version is None: + protocol_version = DEFAULT_NEGOTIATED_VERSION + + # Check if the protocol version is supported + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: + supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) + response = self._create_error_response( + f"Bad Request: Unsupported protocol version: {protocol_version}. " + + f"Supported versions: {supported_versions}", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + return True + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """ Replays events that would have been sent after the specified event ID. diff --git a/src/mcp/types.py b/src/mcp/types.py index 2949ed8..824cee7 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -24,6 +24,14 @@ for reference. LATEST_PROTOCOL_VERSION = "2025-03-26" +""" +The default negotiated version of the Model Context Protocol when no version is specified. +We need this to satisfy the MCP specification, which requires the server to assume a +specific version if none is provided by the client. See section "Protocol Version Header" at +https://modelcontextprotocol.io/specification +""" +DEFAULT_NEGOTIATED_VERSION = "2025-03-26" + ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 615e68e..c4604e3 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -26,6 +26,7 @@ from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, EventCallback, @@ -64,6 +65,17 @@ INIT_REQUEST = { } +# Helper functions +def extract_protocol_version_from_sse(response: requests.Response) -> str: + """Extract the negotiated protocol version from an SSE initialization response.""" + assert response.headers.get("Content-Type") == "text/event-stream" + for line in response.text.splitlines(): + if line.startswith("data: "): + init_data = json.loads(line[6:]) + return init_data["result"]["protocolVersion"] + raise ValueError("Could not extract protocol version from SSE response") + + # Simple in-memory event store for testing class SimpleEventStore(EventStore): """Simple in-memory event store for testing.""" @@ -560,11 +572,17 @@ def test_session_termination(basic_server, basic_server_url): ) assert response.status_code == 200 + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + # Now terminate the session session_id = response.headers.get(MCP_SESSION_ID_HEADER) response = requests.delete( f"{basic_server_url}/mcp", - headers={MCP_SESSION_ID_HEADER: session_id}, + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, ) assert response.status_code == 200 @@ -595,16 +613,20 @@ def test_response(basic_server, basic_server_url): ) assert response.status_code == 200 - # Now terminate the session + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + + # Now get the session ID session_id = response.headers.get(MCP_SESSION_ID_HEADER) - # Try to use the terminated session + # Try to use the session with proper headers tools_response = requests.post( mcp_url, headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, stream=True, @@ -646,12 +668,23 @@ def test_get_sse_stream(basic_server, basic_server_url): session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) assert session_id is not None + # Extract negotiated protocol version from SSE response + init_data = None + assert init_response.headers.get("Content-Type") == "text/event-stream" + for line in init_response.text.splitlines(): + if line.startswith("data: "): + init_data = json.loads(line[6:]) + break + assert init_data is not None + negotiated_version = init_data["result"]["protocolVersion"] + # Now attempt to establish an SSE stream via GET get_response = requests.get( mcp_url, headers={ "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, stream=True, ) @@ -666,6 +699,7 @@ def test_get_sse_stream(basic_server, basic_server_url): headers={ "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, stream=True, ) @@ -694,11 +728,22 @@ def test_get_validation(basic_server, basic_server_url): session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) assert session_id is not None + # Extract negotiated protocol version from SSE response + init_data = None + assert init_response.headers.get("Content-Type") == "text/event-stream" + for line in init_response.text.splitlines(): + if line.startswith("data: "): + init_data = json.loads(line[6:]) + break + assert init_data is not None + negotiated_version = init_data["result"]["protocolVersion"] + # Test without Accept header response = requests.get( mcp_url, headers={ MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, stream=True, ) @@ -711,6 +756,7 @@ def test_get_validation(basic_server, basic_server_url): headers={ "Accept": "application/json", MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, ) assert response.status_code == 406 @@ -1004,6 +1050,7 @@ async def test_streamablehttp_client_resumption(event_server): captured_resumption_token = None captured_notifications = [] tool_started = False + captured_protocol_version = None async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -1032,6 +1079,8 @@ async def test_streamablehttp_client_resumption(event_server): assert isinstance(result, InitializeResult) captured_session_id = get_session_id() assert captured_session_id is not None + # Capture the negotiated protocol version + captured_protocol_version = result.protocolVersion # Start a long-running tool in a task async with anyio.create_task_group() as tg: @@ -1064,10 +1113,12 @@ async def test_streamablehttp_client_resumption(event_server): captured_notifications_pre = captured_notifications.copy() captured_notifications = [] - # Now resume the session with the same mcp-session-id + # Now resume the session with the same mcp-session-id and protocol version headers = {} if captured_session_id: headers[MCP_SESSION_ID_HEADER] = captured_session_id + if captured_protocol_version: + headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( read_stream, @@ -1358,3 +1409,115 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No assert ctx["headers"].get("x-request-id") == f"request-{i}" assert ctx["headers"].get("x-custom-value") == f"value-{i}" assert ctx["headers"].get("authorization") == f"Bearer token-{i}" + + +@pytest.mark.anyio +async def test_client_includes_protocol_version_header_after_init(context_aware_server, basic_server_url): + """Test that client includes mcp-protocol-version header after initialization.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize and get the negotiated version + init_result = await session.initialize() + negotiated_version = init_result.protocolVersion + + # Call a tool that echoes headers to verify the header is present + tool_result = await session.call_tool("echo_headers", {}) + + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify protocol version header is present + assert "mcp-protocol-version" in headers_data + assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version + + +def test_server_validates_protocol_version_header(basic_server, basic_server_url): + """Test that server returns 400 Bad Request version if header unsupported or invalid.""" + # First initialize a session to get a valid session ID + init_response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + + # Test request with invalid protocol version (should fail) + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "invalid-version", + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with unsupported protocol version (should fail) + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with valid protocol version (should succeed) + negotiated_version = extract_protocol_version_from_sse(init_response) + + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, + ) + assert response.status_code == 200 + + +def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_server_url): + """Test server accepts requests without protocol version header.""" + # First initialize a session to get a valid session ID + init_response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + + # Test request without mcp-protocol-version header (backwards compatibility) + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, + stream=True, + ) + assert response.status_code == 200 # Should succeed for backwards compatibility + assert response.headers.get("Content-Type") == "text/event-stream"