Support for http request injection propagation to tools (#816)

This commit is contained in:
ihrpr
2025-05-28 15:59:14 +01:00
committed by GitHub
parent 532b1176f9
commit 70014a2bbb
12 changed files with 413 additions and 35 deletions

View File

@@ -1,3 +1,4 @@
import json
import multiprocessing
import socket
import time
@@ -318,3 +319,187 @@ async def test_sse_client_basic_connection_mounted_app(
# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)
# Test server with request context that returns headers in the response
class RequestContextServer(Server[object, Request]):
def __init__(self):
super().__init__("request_context_server")
@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
headers_info = {}
context = self.request_context
if context.request:
headers_info = dict(context.request.headers)
if name == "echo_headers":
return [TextContent(type="text", text=json.dumps(headers_info))]
elif name == "echo_context":
context_data = {
"request_id": args.get("request_id"),
"headers": headers_info,
}
return [TextContent(type="text", text=json.dumps(context_data))]
return [TextContent(type="text", text=f"Called {name}")]
@self.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="echo_headers",
description="Echoes request headers",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="echo_context",
description="Echoes request context",
inputSchema={
"type": "object",
"properties": {"request_id": {"type": "string"}},
"required": ["request_id"],
},
),
]
def run_context_server(server_port: int) -> None:
"""Run a server that captures request context"""
sse = SseServerTransport("/messages/")
context_server = RequestContextServer()
async def handle_sse(request: Request) -> Response:
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await context_server.run(
streams[0], streams[1], context_server.create_initialization_options()
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
]
)
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting context server on {server_port}")
server.run()
@pytest.fixture()
def context_server(server_port: int) -> Generator[None, None, None]:
"""Fixture that provides a server with request context capture"""
proc = multiprocessing.Process(
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting context server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("waiting for context server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)
yield
print("killing context server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("context server process failed to terminate")
@pytest.mark.anyio
async def test_request_context_propagation(
context_server: None, server_url: str
) -> None:
"""Test that request context is properly propagated through SSE transport."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer test-token",
"X-Custom-Header": "test-value",
"X-Trace-Id": "trace-123",
}
async with sse_client(server_url + "/sse", headers=custom_headers) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
# Call the tool that echoes headers back
tool_result = await session.call_tool("echo_headers", {})
# Parse the JSON response
assert len(tool_result.content) == 1
headers_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
# Verify headers were propagated
assert headers_data.get("authorization") == "Bearer test-token"
assert headers_data.get("x-custom-header") == "test-value"
assert headers_data.get("x-trace-id") == "trace-123"
@pytest.mark.anyio
async def test_request_context_isolation(context_server: None, server_url: str) -> None:
"""Test that request contexts are isolated between different SSE clients."""
contexts = []
# Create multiple clients with different headers
for i in range(3):
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
async with sse_client(server_url + "/sse", headers=headers) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
# Call the tool that echoes context
tool_result = await session.call_tool(
"echo_context", {"request_id": f"request-{i}"}
)
assert len(tool_result.content) == 1
context_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
contexts.append(context_data)
# Verify each request had its own context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["request_id"] == f"request-{i}"
assert ctx["headers"].get("x-request-id") == f"request-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"