mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
feat: implement MCP-Protocol-Version header requirement for HTTP transport (#898)
This commit is contained in:
@@ -17,7 +17,13 @@ from urllib.parse import urlencode, urljoin
|
|||||||
import anyio
|
import anyio
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken
|
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||||
|
from mcp.shared.auth import (
|
||||||
|
OAuthClientInformationFull,
|
||||||
|
OAuthClientMetadata,
|
||||||
|
OAuthMetadata,
|
||||||
|
OAuthToken,
|
||||||
|
)
|
||||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -121,7 +127,7 @@ class OAuthClientProvider(httpx.Auth):
|
|||||||
# Extract base URL per MCP spec
|
# Extract base URL per MCP spec
|
||||||
auth_base_url = self._get_authorization_base_url(server_url)
|
auth_base_url = self._get_authorization_base_url(server_url)
|
||||||
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
|
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
|
||||||
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
|
headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
|
|||||||
from mcp.shared.message import ClientMessageMetadata, SessionMessage
|
from mcp.shared.message import ClientMessageMetadata, SessionMessage
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
ErrorData,
|
ErrorData,
|
||||||
|
InitializeResult,
|
||||||
JSONRPCError,
|
JSONRPCError,
|
||||||
JSONRPCMessage,
|
JSONRPCMessage,
|
||||||
JSONRPCNotification,
|
JSONRPCNotification,
|
||||||
@@ -39,6 +40,7 @@ StreamReader = MemoryObjectReceiveStream[SessionMessage]
|
|||||||
GetSessionIdCallback = Callable[[], str | None]
|
GetSessionIdCallback = Callable[[], str | None]
|
||||||
|
|
||||||
MCP_SESSION_ID = "mcp-session-id"
|
MCP_SESSION_ID = "mcp-session-id"
|
||||||
|
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
|
||||||
LAST_EVENT_ID = "last-event-id"
|
LAST_EVENT_ID = "last-event-id"
|
||||||
CONTENT_TYPE = "content-type"
|
CONTENT_TYPE = "content-type"
|
||||||
ACCEPT = "Accept"
|
ACCEPT = "Accept"
|
||||||
@@ -97,17 +99,20 @@ class StreamableHTTPTransport:
|
|||||||
)
|
)
|
||||||
self.auth = auth
|
self.auth = auth
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
|
self.protocol_version = None
|
||||||
self.request_headers = {
|
self.request_headers = {
|
||||||
ACCEPT: f"{JSON}, {SSE}",
|
ACCEPT: f"{JSON}, {SSE}",
|
||||||
CONTENT_TYPE: JSON,
|
CONTENT_TYPE: JSON,
|
||||||
**self.headers,
|
**self.headers,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||||
"""Update headers with session ID if available."""
|
"""Update headers with session ID and protocol version if available."""
|
||||||
headers = base_headers.copy()
|
headers = base_headers.copy()
|
||||||
if self.session_id:
|
if self.session_id:
|
||||||
headers[MCP_SESSION_ID] = self.session_id
|
headers[MCP_SESSION_ID] = self.session_id
|
||||||
|
if self.protocol_version:
|
||||||
|
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||||
@@ -128,12 +133,28 @@ class StreamableHTTPTransport:
|
|||||||
self.session_id = new_session_id
|
self.session_id = new_session_id
|
||||||
logger.info(f"Received session ID: {self.session_id}")
|
logger.info(f"Received session ID: {self.session_id}")
|
||||||
|
|
||||||
|
def _maybe_extract_protocol_version_from_message(
|
||||||
|
self,
|
||||||
|
message: JSONRPCMessage,
|
||||||
|
) -> None:
|
||||||
|
"""Extract protocol version from initialization response message."""
|
||||||
|
if isinstance(message.root, JSONRPCResponse) and message.root.result:
|
||||||
|
try:
|
||||||
|
# Parse the result as InitializeResult for type safety
|
||||||
|
init_result = InitializeResult.model_validate(message.root.result)
|
||||||
|
self.protocol_version = str(init_result.protocolVersion)
|
||||||
|
logger.info(f"Negotiated protocol version: {self.protocol_version}")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
|
||||||
|
logger.warning(f"Raw result: {message.root.result}")
|
||||||
|
|
||||||
async def _handle_sse_event(
|
async def _handle_sse_event(
|
||||||
self,
|
self,
|
||||||
sse: ServerSentEvent,
|
sse: ServerSentEvent,
|
||||||
read_stream_writer: StreamWriter,
|
read_stream_writer: StreamWriter,
|
||||||
original_request_id: RequestId | None = None,
|
original_request_id: RequestId | None = None,
|
||||||
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
|
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
is_initialization: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Handle an SSE event, returning True if the response is complete."""
|
"""Handle an SSE event, returning True if the response is complete."""
|
||||||
if sse.event == "message":
|
if sse.event == "message":
|
||||||
@@ -141,6 +162,10 @@ class StreamableHTTPTransport:
|
|||||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||||
logger.debug(f"SSE message: {message}")
|
logger.debug(f"SSE message: {message}")
|
||||||
|
|
||||||
|
# Extract protocol version from initialization response
|
||||||
|
if is_initialization:
|
||||||
|
self._maybe_extract_protocol_version_from_message(message)
|
||||||
|
|
||||||
# If this is a response and we have original_request_id, replace it
|
# If this is a response and we have original_request_id, replace it
|
||||||
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||||
message.root.id = original_request_id
|
message.root.id = original_request_id
|
||||||
@@ -174,7 +199,7 @@ class StreamableHTTPTransport:
|
|||||||
if not self.session_id:
|
if not self.session_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
headers = self._update_headers_with_session(self.request_headers)
|
headers = self._prepare_request_headers(self.request_headers)
|
||||||
|
|
||||||
async with aconnect_sse(
|
async with aconnect_sse(
|
||||||
client,
|
client,
|
||||||
@@ -194,7 +219,7 @@ class StreamableHTTPTransport:
|
|||||||
|
|
||||||
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
||||||
"""Handle a resumption request using GET with SSE."""
|
"""Handle a resumption request using GET with SSE."""
|
||||||
headers = self._update_headers_with_session(ctx.headers)
|
headers = self._prepare_request_headers(ctx.headers)
|
||||||
if ctx.metadata and ctx.metadata.resumption_token:
|
if ctx.metadata and ctx.metadata.resumption_token:
|
||||||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
||||||
else:
|
else:
|
||||||
@@ -227,7 +252,7 @@ class StreamableHTTPTransport:
|
|||||||
|
|
||||||
async def _handle_post_request(self, ctx: RequestContext) -> None:
|
async def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||||
"""Handle a POST request with response processing."""
|
"""Handle a POST request with response processing."""
|
||||||
headers = self._update_headers_with_session(ctx.headers)
|
headers = self._prepare_request_headers(ctx.headers)
|
||||||
message = ctx.session_message.message
|
message = ctx.session_message.message
|
||||||
is_initialization = self._is_initialization_request(message)
|
is_initialization = self._is_initialization_request(message)
|
||||||
|
|
||||||
@@ -256,9 +281,9 @@ class StreamableHTTPTransport:
|
|||||||
content_type = response.headers.get(CONTENT_TYPE, "").lower()
|
content_type = response.headers.get(CONTENT_TYPE, "").lower()
|
||||||
|
|
||||||
if content_type.startswith(JSON):
|
if content_type.startswith(JSON):
|
||||||
await self._handle_json_response(response, ctx.read_stream_writer)
|
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
|
||||||
elif content_type.startswith(SSE):
|
elif content_type.startswith(SSE):
|
||||||
await self._handle_sse_response(response, ctx)
|
await self._handle_sse_response(response, ctx, is_initialization)
|
||||||
else:
|
else:
|
||||||
await self._handle_unexpected_content_type(
|
await self._handle_unexpected_content_type(
|
||||||
content_type,
|
content_type,
|
||||||
@@ -269,18 +294,29 @@ class StreamableHTTPTransport:
|
|||||||
self,
|
self,
|
||||||
response: httpx.Response,
|
response: httpx.Response,
|
||||||
read_stream_writer: StreamWriter,
|
read_stream_writer: StreamWriter,
|
||||||
|
is_initialization: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle JSON response from the server."""
|
"""Handle JSON response from the server."""
|
||||||
try:
|
try:
|
||||||
content = await response.aread()
|
content = await response.aread()
|
||||||
message = JSONRPCMessage.model_validate_json(content)
|
message = JSONRPCMessage.model_validate_json(content)
|
||||||
|
|
||||||
|
# Extract protocol version from initialization response
|
||||||
|
if is_initialization:
|
||||||
|
self._maybe_extract_protocol_version_from_message(message)
|
||||||
|
|
||||||
session_message = SessionMessage(message)
|
session_message = SessionMessage(message)
|
||||||
await read_stream_writer.send(session_message)
|
await read_stream_writer.send(session_message)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Error parsing JSON response: {exc}")
|
logger.error(f"Error parsing JSON response: {exc}")
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
|
|
||||||
async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
|
async def _handle_sse_response(
|
||||||
|
self,
|
||||||
|
response: httpx.Response,
|
||||||
|
ctx: RequestContext,
|
||||||
|
is_initialization: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Handle SSE response from the server."""
|
"""Handle SSE response from the server."""
|
||||||
try:
|
try:
|
||||||
event_source = EventSource(response)
|
event_source = EventSource(response)
|
||||||
@@ -289,6 +325,7 @@ class StreamableHTTPTransport:
|
|||||||
sse,
|
sse,
|
||||||
ctx.read_stream_writer,
|
ctx.read_stream_writer,
|
||||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||||
|
is_initialization=is_initialization,
|
||||||
)
|
)
|
||||||
# If the SSE event indicates completion, like returning respose/error
|
# If the SSE event indicates completion, like returning respose/error
|
||||||
# break the loop
|
# break the loop
|
||||||
@@ -385,7 +422,7 @@ class StreamableHTTPTransport:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
headers = self._update_headers_with_session(self.request_headers)
|
headers = self._prepare_request_headers(self.request_headers)
|
||||||
response = await client.delete(self.url, headers=headers)
|
response = await client.delete(self.url, headers=headers)
|
||||||
|
|
||||||
if response.status_code == 405:
|
if response.status_code == 405:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from mcp.server.auth.handlers.token import TokenHandler
|
|||||||
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
||||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||||
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
|
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
|
||||||
|
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
|
||||||
from mcp.shared.auth import OAuthMetadata
|
from mcp.shared.auth import OAuthMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -55,7 +56,7 @@ def cors_middleware(
|
|||||||
app=request_response(handler),
|
app=request_response(handler),
|
||||||
allow_origins="*",
|
allow_origins="*",
|
||||||
allow_methods=allow_methods,
|
allow_methods=allow_methods,
|
||||||
allow_headers=["mcp-protocol-version"],
|
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
|
||||||
)
|
)
|
||||||
return cors_app
|
return cors_app
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ from starlette.responses import Response
|
|||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||||
|
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
|
DEFAULT_NEGOTIATED_VERSION,
|
||||||
INTERNAL_ERROR,
|
INTERNAL_ERROR,
|
||||||
INVALID_PARAMS,
|
INVALID_PARAMS,
|
||||||
INVALID_REQUEST,
|
INVALID_REQUEST,
|
||||||
@@ -45,6 +47,7 @@ MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB
|
|||||||
|
|
||||||
# Header names
|
# Header names
|
||||||
MCP_SESSION_ID_HEADER = "mcp-session-id"
|
MCP_SESSION_ID_HEADER = "mcp-session-id"
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
|
||||||
LAST_EVENT_ID_HEADER = "last-event-id"
|
LAST_EVENT_ID_HEADER = "last-event-id"
|
||||||
|
|
||||||
# Content types
|
# Content types
|
||||||
@@ -293,7 +296,7 @@ class StreamableHTTPServerTransport:
|
|||||||
has_json, has_sse = self._check_accept_headers(request)
|
has_json, has_sse = self._check_accept_headers(request)
|
||||||
if not (has_json and has_sse):
|
if not (has_json and has_sse):
|
||||||
response = self._create_error_response(
|
response = self._create_error_response(
|
||||||
("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
|
("Not Acceptable: Client must accept both application/json and text/event-stream"),
|
||||||
HTTPStatus.NOT_ACCEPTABLE,
|
HTTPStatus.NOT_ACCEPTABLE,
|
||||||
)
|
)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
@@ -353,8 +356,7 @@ class StreamableHTTPServerTransport:
|
|||||||
)
|
)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
return
|
return
|
||||||
# For non-initialization requests, validate the session
|
elif not await self._validate_request_headers(request, send):
|
||||||
elif not await self._validate_session(request, send):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# For notifications and responses only, return 202 Accepted
|
# For notifications and responses only, return 202 Accepted
|
||||||
@@ -513,8 +515,9 @@ class StreamableHTTPServerTransport:
|
|||||||
await response(request.scope, request.receive, send)
|
await response(request.scope, request.receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not await self._validate_session(request, send):
|
if not await self._validate_request_headers(request, send):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Handle resumability: check for Last-Event-ID header
|
# Handle resumability: check for Last-Event-ID header
|
||||||
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
|
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
|
||||||
await self._replay_events(last_event_id, request, send)
|
await self._replay_events(last_event_id, request, send)
|
||||||
@@ -593,7 +596,7 @@ class StreamableHTTPServerTransport:
|
|||||||
await response(request.scope, request.receive, send)
|
await response(request.scope, request.receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not await self._validate_session(request, send):
|
if not await self._validate_request_headers(request, send):
|
||||||
return
|
return
|
||||||
|
|
||||||
await self._terminate_session()
|
await self._terminate_session()
|
||||||
@@ -653,6 +656,13 @@ class StreamableHTTPServerTransport:
|
|||||||
)
|
)
|
||||||
await response(request.scope, request.receive, send)
|
await response(request.scope, request.receive, send)
|
||||||
|
|
||||||
|
async def _validate_request_headers(self, request: Request, send: Send) -> bool:
|
||||||
|
if not await self._validate_session(request, send):
|
||||||
|
return False
|
||||||
|
if not await self._validate_protocol_version(request, send):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
async def _validate_session(self, request: Request, send: Send) -> bool:
|
async def _validate_session(self, request: Request, send: Send) -> bool:
|
||||||
"""Validate the session ID in the request."""
|
"""Validate the session ID in the request."""
|
||||||
if not self.mcp_session_id:
|
if not self.mcp_session_id:
|
||||||
@@ -682,6 +692,28 @@ class StreamableHTTPServerTransport:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
|
||||||
|
"""Validate the protocol version header in the request."""
|
||||||
|
# Get the protocol version from the request headers
|
||||||
|
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
|
||||||
|
|
||||||
|
# If no protocol version provided, assume default version
|
||||||
|
if protocol_version is None:
|
||||||
|
protocol_version = DEFAULT_NEGOTIATED_VERSION
|
||||||
|
|
||||||
|
# Check if the protocol version is supported
|
||||||
|
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||||
|
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
|
||||||
|
response = self._create_error_response(
|
||||||
|
f"Bad Request: Unsupported protocol version: {protocol_version}. "
|
||||||
|
+ f"Supported versions: {supported_versions}",
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
)
|
||||||
|
await response(request.scope, request.receive, send)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
|
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
|
||||||
"""
|
"""
|
||||||
Replays events that would have been sent after the specified event ID.
|
Replays events that would have been sent after the specified event ID.
|
||||||
|
|||||||
@@ -24,6 +24,14 @@ for reference.
|
|||||||
|
|
||||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||||
|
|
||||||
|
"""
|
||||||
|
The default negotiated version of the Model Context Protocol when no version is specified.
|
||||||
|
We need this to satisfy the MCP specification, which requires the server to assume a
|
||||||
|
specific version if none is provided by the client. See section "Protocol Version Header" at
|
||||||
|
https://modelcontextprotocol.io/specification
|
||||||
|
"""
|
||||||
|
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
|
||||||
|
|
||||||
ProgressToken = str | int
|
ProgressToken = str | int
|
||||||
Cursor = str
|
Cursor = str
|
||||||
Role = Literal["user", "assistant"]
|
Role = Literal["user", "assistant"]
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from mcp.client.session import ClientSession
|
|||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.server.streamable_http import (
|
from mcp.server.streamable_http import (
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER,
|
||||||
MCP_SESSION_ID_HEADER,
|
MCP_SESSION_ID_HEADER,
|
||||||
SESSION_ID_PATTERN,
|
SESSION_ID_PATTERN,
|
||||||
EventCallback,
|
EventCallback,
|
||||||
@@ -64,6 +65,17 @@ INIT_REQUEST = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Helper functions
|
||||||
|
def extract_protocol_version_from_sse(response: requests.Response) -> str:
|
||||||
|
"""Extract the negotiated protocol version from an SSE initialization response."""
|
||||||
|
assert response.headers.get("Content-Type") == "text/event-stream"
|
||||||
|
for line in response.text.splitlines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
init_data = json.loads(line[6:])
|
||||||
|
return init_data["result"]["protocolVersion"]
|
||||||
|
raise ValueError("Could not extract protocol version from SSE response")
|
||||||
|
|
||||||
|
|
||||||
# Simple in-memory event store for testing
|
# Simple in-memory event store for testing
|
||||||
class SimpleEventStore(EventStore):
|
class SimpleEventStore(EventStore):
|
||||||
"""Simple in-memory event store for testing."""
|
"""Simple in-memory event store for testing."""
|
||||||
@@ -560,11 +572,17 @@ def test_session_termination(basic_server, basic_server_url):
|
|||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Extract negotiated protocol version from SSE response
|
||||||
|
negotiated_version = extract_protocol_version_from_sse(response)
|
||||||
|
|
||||||
# Now terminate the session
|
# Now terminate the session
|
||||||
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
||||||
response = requests.delete(
|
response = requests.delete(
|
||||||
f"{basic_server_url}/mcp",
|
f"{basic_server_url}/mcp",
|
||||||
headers={MCP_SESSION_ID_HEADER: session_id},
|
headers={
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
@@ -595,16 +613,20 @@ def test_response(basic_server, basic_server_url):
|
|||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# Now terminate the session
|
# Extract negotiated protocol version from SSE response
|
||||||
|
negotiated_version = extract_protocol_version_from_sse(response)
|
||||||
|
|
||||||
|
# Now get the session ID
|
||||||
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
||||||
|
|
||||||
# Try to use the terminated session
|
# Try to use the session with proper headers
|
||||||
tools_response = requests.post(
|
tools_response = requests.post(
|
||||||
mcp_url,
|
mcp_url,
|
||||||
headers={
|
headers={
|
||||||
"Accept": "application/json, text/event-stream",
|
"Accept": "application/json, text/event-stream",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier
|
MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||||
},
|
},
|
||||||
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"},
|
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"},
|
||||||
stream=True,
|
stream=True,
|
||||||
@@ -646,12 +668,23 @@ def test_get_sse_stream(basic_server, basic_server_url):
|
|||||||
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||||
assert session_id is not None
|
assert session_id is not None
|
||||||
|
|
||||||
|
# Extract negotiated protocol version from SSE response
|
||||||
|
init_data = None
|
||||||
|
assert init_response.headers.get("Content-Type") == "text/event-stream"
|
||||||
|
for line in init_response.text.splitlines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
init_data = json.loads(line[6:])
|
||||||
|
break
|
||||||
|
assert init_data is not None
|
||||||
|
negotiated_version = init_data["result"]["protocolVersion"]
|
||||||
|
|
||||||
# Now attempt to establish an SSE stream via GET
|
# Now attempt to establish an SSE stream via GET
|
||||||
get_response = requests.get(
|
get_response = requests.get(
|
||||||
mcp_url,
|
mcp_url,
|
||||||
headers={
|
headers={
|
||||||
"Accept": "text/event-stream",
|
"Accept": "text/event-stream",
|
||||||
MCP_SESSION_ID_HEADER: session_id,
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||||
},
|
},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
@@ -666,6 +699,7 @@ def test_get_sse_stream(basic_server, basic_server_url):
|
|||||||
headers={
|
headers={
|
||||||
"Accept": "text/event-stream",
|
"Accept": "text/event-stream",
|
||||||
MCP_SESSION_ID_HEADER: session_id,
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||||
},
|
},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
@@ -694,11 +728,22 @@ def test_get_validation(basic_server, basic_server_url):
|
|||||||
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||||
assert session_id is not None
|
assert session_id is not None
|
||||||
|
|
||||||
|
# Extract negotiated protocol version from SSE response
|
||||||
|
init_data = None
|
||||||
|
assert init_response.headers.get("Content-Type") == "text/event-stream"
|
||||||
|
for line in init_response.text.splitlines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
init_data = json.loads(line[6:])
|
||||||
|
break
|
||||||
|
assert init_data is not None
|
||||||
|
negotiated_version = init_data["result"]["protocolVersion"]
|
||||||
|
|
||||||
# Test without Accept header
|
# Test without Accept header
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
mcp_url,
|
mcp_url,
|
||||||
headers={
|
headers={
|
||||||
MCP_SESSION_ID_HEADER: session_id,
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||||
},
|
},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
@@ -711,6 +756,7 @@ def test_get_validation(basic_server, basic_server_url):
|
|||||||
headers={
|
headers={
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
MCP_SESSION_ID_HEADER: session_id,
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 406
|
assert response.status_code == 406
|
||||||
@@ -1004,6 +1050,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
|||||||
captured_resumption_token = None
|
captured_resumption_token = None
|
||||||
captured_notifications = []
|
captured_notifications = []
|
||||||
tool_started = False
|
tool_started = False
|
||||||
|
captured_protocol_version = None
|
||||||
|
|
||||||
async def message_handler(
|
async def message_handler(
|
||||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||||
@@ -1032,6 +1079,8 @@ async def test_streamablehttp_client_resumption(event_server):
|
|||||||
assert isinstance(result, InitializeResult)
|
assert isinstance(result, InitializeResult)
|
||||||
captured_session_id = get_session_id()
|
captured_session_id = get_session_id()
|
||||||
assert captured_session_id is not None
|
assert captured_session_id is not None
|
||||||
|
# Capture the negotiated protocol version
|
||||||
|
captured_protocol_version = result.protocolVersion
|
||||||
|
|
||||||
# Start a long-running tool in a task
|
# Start a long-running tool in a task
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
@@ -1064,10 +1113,12 @@ async def test_streamablehttp_client_resumption(event_server):
|
|||||||
captured_notifications_pre = captured_notifications.copy()
|
captured_notifications_pre = captured_notifications.copy()
|
||||||
captured_notifications = []
|
captured_notifications = []
|
||||||
|
|
||||||
# Now resume the session with the same mcp-session-id
|
# Now resume the session with the same mcp-session-id and protocol version
|
||||||
headers = {}
|
headers = {}
|
||||||
if captured_session_id:
|
if captured_session_id:
|
||||||
headers[MCP_SESSION_ID_HEADER] = captured_session_id
|
headers[MCP_SESSION_ID_HEADER] = captured_session_id
|
||||||
|
if captured_protocol_version:
|
||||||
|
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
|
||||||
|
|
||||||
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
|
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
|
||||||
read_stream,
|
read_stream,
|
||||||
@@ -1358,3 +1409,115 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No
|
|||||||
assert ctx["headers"].get("x-request-id") == f"request-{i}"
|
assert ctx["headers"].get("x-request-id") == f"request-{i}"
|
||||||
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
|
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
|
||||||
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
|
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_includes_protocol_version_header_after_init(context_aware_server, basic_server_url):
|
||||||
|
"""Test that client includes mcp-protocol-version header after initialization."""
|
||||||
|
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
_,
|
||||||
|
):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
# Initialize and get the negotiated version
|
||||||
|
init_result = await session.initialize()
|
||||||
|
negotiated_version = init_result.protocolVersion
|
||||||
|
|
||||||
|
# Call a tool that echoes headers to verify the header is present
|
||||||
|
tool_result = await session.call_tool("echo_headers", {})
|
||||||
|
|
||||||
|
assert len(tool_result.content) == 1
|
||||||
|
assert isinstance(tool_result.content[0], TextContent)
|
||||||
|
headers_data = json.loads(tool_result.content[0].text)
|
||||||
|
|
||||||
|
# Verify protocol version header is present
|
||||||
|
assert "mcp-protocol-version" in headers_data
|
||||||
|
assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_validates_protocol_version_header(basic_server, basic_server_url):
|
||||||
|
"""Test that server returns 400 Bad Request version if header unsupported or invalid."""
|
||||||
|
# First initialize a session to get a valid session ID
|
||||||
|
init_response = requests.post(
|
||||||
|
f"{basic_server_url}/mcp",
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json=INIT_REQUEST,
|
||||||
|
)
|
||||||
|
assert init_response.status_code == 200
|
||||||
|
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||||
|
|
||||||
|
# Test request with invalid protocol version (should fail)
|
||||||
|
response = requests.post(
|
||||||
|
f"{basic_server_url}/mcp",
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: "invalid-version",
|
||||||
|
},
|
||||||
|
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower()
|
||||||
|
|
||||||
|
# Test request with unsupported protocol version (should fail)
|
||||||
|
response = requests.post(
|
||||||
|
f"{basic_server_url}/mcp",
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version
|
||||||
|
},
|
||||||
|
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower()
|
||||||
|
|
||||||
|
# Test request with valid protocol version (should succeed)
|
||||||
|
negotiated_version = extract_protocol_version_from_sse(init_response)
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"{basic_server_url}/mcp",
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
|
||||||
|
},
|
||||||
|
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_server_url):
|
||||||
|
"""Test server accepts requests without protocol version header."""
|
||||||
|
# First initialize a session to get a valid session ID
|
||||||
|
init_response = requests.post(
|
||||||
|
f"{basic_server_url}/mcp",
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json=INIT_REQUEST,
|
||||||
|
)
|
||||||
|
assert init_response.status_code == 200
|
||||||
|
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
|
||||||
|
|
||||||
|
# Test request without mcp-protocol-version header (backwards compatibility)
|
||||||
|
response = requests.post(
|
||||||
|
f"{basic_server_url}/mcp",
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
MCP_SESSION_ID_HEADER: session_id,
|
||||||
|
},
|
||||||
|
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200 # Should succeed for backwards compatibility
|
||||||
|
assert response.headers.get("Content-Type") == "text/event-stream"
|
||||||
|
|||||||
Reference in New Issue
Block a user