mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Support for http request injection propagation in StreamableHttp (#833)
This commit is contained in:
@@ -4,6 +4,7 @@ Tests for the StreamableHTTP server and client transport.
|
||||
Contains tests for both server and client sides of the StreamableHTTP transport.
|
||||
"""
|
||||
|
||||
import json
|
||||
import multiprocessing
|
||||
import socket
|
||||
import time
|
||||
@@ -17,6 +18,7 @@ import requests
|
||||
import uvicorn
|
||||
from pydantic import AnyUrl
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.routing import Mount
|
||||
|
||||
import mcp.types as types
|
||||
@@ -1223,3 +1225,203 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
|
||||
captured_message_params.messages[0].content.text
|
||||
== "Server needs client sampling"
|
||||
)
|
||||
|
||||
|
||||
# Context-aware server implementation for testing request context propagation
|
||||
class ContextAwareServerTest(Server):
|
||||
def __init__(self):
|
||||
super().__init__("ContextAwareServer")
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
return [
|
||||
Tool(
|
||||
name="echo_headers",
|
||||
description="Echo request headers from context",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
Tool(
|
||||
name="echo_context",
|
||||
description="Echo request context with custom data",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"request_id": {"type": "string"},
|
||||
},
|
||||
"required": ["request_id"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@self.call_tool()
|
||||
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
||||
ctx = self.request_context
|
||||
|
||||
if name == "echo_headers":
|
||||
# Access the request object from context
|
||||
headers_info = {}
|
||||
if ctx.request and isinstance(ctx.request, Request):
|
||||
headers_info = dict(ctx.request.headers)
|
||||
return [
|
||||
TextContent(
|
||||
type="text",
|
||||
text=json.dumps(headers_info),
|
||||
)
|
||||
]
|
||||
|
||||
elif name == "echo_context":
|
||||
# Return full context information
|
||||
context_data = {
|
||||
"request_id": args.get("request_id"),
|
||||
"headers": {},
|
||||
"method": None,
|
||||
"path": None,
|
||||
}
|
||||
if ctx.request and isinstance(ctx.request, Request):
|
||||
request = ctx.request
|
||||
context_data["headers"] = dict(request.headers)
|
||||
context_data["method"] = request.method
|
||||
context_data["path"] = request.url.path
|
||||
return [
|
||||
TextContent(
|
||||
type="text",
|
||||
text=json.dumps(context_data),
|
||||
)
|
||||
]
|
||||
|
||||
return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
||||
|
||||
|
||||
# Server runner for context-aware testing
|
||||
def run_context_aware_server(port: int):
|
||||
"""Run the context-aware test server."""
|
||||
server = ContextAwareServerTest()
|
||||
|
||||
session_manager = StreamableHTTPSessionManager(
|
||||
app=server,
|
||||
event_store=None,
|
||||
json_response=False,
|
||||
)
|
||||
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Mount("/mcp", app=session_manager.handle_request),
|
||||
],
|
||||
lifespan=lambda app: session_manager.run(),
|
||||
)
|
||||
|
||||
server_instance = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app,
|
||||
host="127.0.0.1",
|
||||
port=port,
|
||||
log_level="error",
|
||||
)
|
||||
)
|
||||
server_instance.run()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start the context-aware server in a separate process."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_context_aware_server, args=(basic_server_port,), daemon=True
|
||||
)
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", basic_server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Context-aware server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
proc.kill()
|
||||
proc.join(timeout=2)
|
||||
if proc.is_alive():
|
||||
print("Context-aware server process failed to terminate")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_request_context_propagation(
|
||||
context_aware_server: None, basic_server_url: str
|
||||
) -> None:
|
||||
"""Test that request context is properly propagated through StreamableHTTP."""
|
||||
custom_headers = {
|
||||
"Authorization": "Bearer test-token",
|
||||
"X-Custom-Header": "test-value",
|
||||
"X-Trace-Id": "trace-123",
|
||||
}
|
||||
|
||||
async with streamablehttp_client(
|
||||
f"{basic_server_url}/mcp", headers=custom_headers
|
||||
) as (read_stream, write_stream, _):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.serverInfo.name == "ContextAwareServer"
|
||||
|
||||
# Call the tool that echoes headers back
|
||||
tool_result = await session.call_tool("echo_headers", {})
|
||||
|
||||
# Parse the JSON response
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
headers_data = json.loads(tool_result.content[0].text)
|
||||
|
||||
# Verify headers were propagated
|
||||
assert headers_data.get("authorization") == "Bearer test-token"
|
||||
assert headers_data.get("x-custom-header") == "test-value"
|
||||
assert headers_data.get("x-trace-id") == "trace-123"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_request_context_isolation(
|
||||
context_aware_server: None, basic_server_url: str
|
||||
) -> None:
|
||||
"""Test that request contexts are isolated between StreamableHTTP clients."""
|
||||
contexts = []
|
||||
|
||||
# Create multiple clients with different headers
|
||||
for i in range(3):
|
||||
headers = {
|
||||
"X-Request-Id": f"request-{i}",
|
||||
"X-Custom-Value": f"value-{i}",
|
||||
"Authorization": f"Bearer token-{i}",
|
||||
}
|
||||
|
||||
async with streamablehttp_client(
|
||||
f"{basic_server_url}/mcp", headers=headers
|
||||
) as (read_stream, write_stream, _):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
|
||||
# Call the tool that echoes context
|
||||
tool_result = await session.call_tool(
|
||||
"echo_context", {"request_id": f"request-{i}"}
|
||||
)
|
||||
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
context_data = json.loads(tool_result.content[0].text)
|
||||
contexts.append(context_data)
|
||||
|
||||
# Verify each request had its own context
|
||||
assert len(contexts) == 3
|
||||
for i, ctx in enumerate(contexts):
|
||||
assert ctx["request_id"] == f"request-{i}"
|
||||
assert ctx["headers"].get("x-request-id") == f"request-{i}"
|
||||
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
|
||||
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
|
||||
|
||||
Reference in New Issue
Block a user