mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
Support for http request injection propagation to tools (#816)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import multiprocessing
|
||||
import socket
|
||||
import time
|
||||
@@ -318,3 +319,187 @@ async def test_sse_client_basic_connection_mounted_app(
|
||||
# Test ping
|
||||
ping_result = await session.send_ping()
|
||||
assert isinstance(ping_result, EmptyResult)
|
||||
|
||||
|
||||
# Test server with request context that returns headers in the response
|
||||
class RequestContextServer(Server[object, Request]):
|
||||
def __init__(self):
|
||||
super().__init__("request_context_server")
|
||||
|
||||
@self.call_tool()
|
||||
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
||||
headers_info = {}
|
||||
context = self.request_context
|
||||
if context.request:
|
||||
headers_info = dict(context.request.headers)
|
||||
|
||||
if name == "echo_headers":
|
||||
return [TextContent(type="text", text=json.dumps(headers_info))]
|
||||
elif name == "echo_context":
|
||||
context_data = {
|
||||
"request_id": args.get("request_id"),
|
||||
"headers": headers_info,
|
||||
}
|
||||
return [TextContent(type="text", text=json.dumps(context_data))]
|
||||
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
return [
|
||||
Tool(
|
||||
name="echo_headers",
|
||||
description="Echoes request headers",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
Tool(
|
||||
name="echo_context",
|
||||
description="Echoes request context",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"request_id": {"type": "string"}},
|
||||
"required": ["request_id"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def run_context_server(server_port: int) -> None:
|
||||
"""Run a server that captures request context"""
|
||||
sse = SseServerTransport("/messages/")
|
||||
context_server = RequestContextServer()
|
||||
|
||||
async def handle_sse(request: Request) -> Response:
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
await context_server.run(
|
||||
streams[0], streams[1], context_server.create_initialization_options()
|
||||
)
|
||||
return Response()
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
]
|
||||
)
|
||||
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
print(f"starting context server on {server_port}")
|
||||
server.run()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def context_server(server_port: int) -> Generator[None, None, None]:
|
||||
"""Fixture that provides a server with request context capture"""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
|
||||
)
|
||||
print("starting context server process")
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
print("waiting for context server to start")
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Context server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
print("killing context server")
|
||||
proc.kill()
|
||||
proc.join(timeout=2)
|
||||
if proc.is_alive():
|
||||
print("context server process failed to terminate")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_request_context_propagation(
|
||||
context_server: None, server_url: str
|
||||
) -> None:
|
||||
"""Test that request context is properly propagated through SSE transport."""
|
||||
# Test with custom headers
|
||||
custom_headers = {
|
||||
"Authorization": "Bearer test-token",
|
||||
"X-Custom-Header": "test-value",
|
||||
"X-Trace-Id": "trace-123",
|
||||
}
|
||||
|
||||
async with sse_client(server_url + "/sse", headers=custom_headers) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
|
||||
# Call the tool that echoes headers back
|
||||
tool_result = await session.call_tool("echo_headers", {})
|
||||
|
||||
# Parse the JSON response
|
||||
|
||||
assert len(tool_result.content) == 1
|
||||
headers_data = json.loads(
|
||||
tool_result.content[0].text
|
||||
if tool_result.content[0].type == "text"
|
||||
else "{}"
|
||||
)
|
||||
|
||||
# Verify headers were propagated
|
||||
assert headers_data.get("authorization") == "Bearer test-token"
|
||||
assert headers_data.get("x-custom-header") == "test-value"
|
||||
assert headers_data.get("x-trace-id") == "trace-123"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_request_context_isolation(context_server: None, server_url: str) -> None:
|
||||
"""Test that request contexts are isolated between different SSE clients."""
|
||||
contexts = []
|
||||
|
||||
# Create multiple clients with different headers
|
||||
for i in range(3):
|
||||
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
|
||||
|
||||
async with sse_client(server_url + "/sse", headers=headers) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
|
||||
# Call the tool that echoes context
|
||||
tool_result = await session.call_tool(
|
||||
"echo_context", {"request_id": f"request-{i}"}
|
||||
)
|
||||
|
||||
assert len(tool_result.content) == 1
|
||||
context_data = json.loads(
|
||||
tool_result.content[0].text
|
||||
if tool_result.content[0].type == "text"
|
||||
else "{}"
|
||||
)
|
||||
contexts.append(context_data)
|
||||
|
||||
# Verify each request had its own context
|
||||
assert len(contexts) == 3
|
||||
for i, ctx in enumerate(contexts):
|
||||
assert ctx["request_id"] == f"request-{i}"
|
||||
assert ctx["headers"].get("x-request-id") == f"request-{i}"
|
||||
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
|
||||
|
||||
Reference in New Issue
Block a user