Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -72,9 +72,7 @@ class SimpleEventStore(EventStore):
self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
self._event_id_counter = 0
async def store_event(
self, stream_id: StreamId, message: types.JSONRPCMessage
) -> EventId:
async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId:
"""Store an event and return its ID."""
self._event_id_counter += 1
event_id = str(self._event_id_counter)
@@ -156,9 +154,7 @@ class ServerTest(Server):
# When the tool is called, send a notification to test GET stream
if name == "test_tool_with_standalone_notification":
await ctx.session.send_resource_updated(
uri=AnyUrl("http://test_resource")
)
await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource"))
return [TextContent(type="text", text=f"Called {name}")]
elif name == "long_running_with_checkpoints":
@@ -189,9 +185,7 @@ class ServerTest(Server):
messages=[
types.SamplingMessage(
role="user",
content=types.TextContent(
type="text", text="Server needs client sampling"
),
content=types.TextContent(type="text", text="Server needs client sampling"),
)
],
max_tokens=100,
@@ -199,11 +193,7 @@ class ServerTest(Server):
)
# Return the sampling result in the tool response
response = (
sampling_result.content.text
if sampling_result.content.type == "text"
else None
)
response = sampling_result.content.text if sampling_result.content.type == "text" else None
return [
TextContent(
type="text",
@@ -214,9 +204,7 @@ class ServerTest(Server):
return [TextContent(type="text", text=f"Called {name}")]
def create_app(
is_json_response_enabled=False, event_store: EventStore | None = None
) -> Starlette:
def create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette:
"""Create a Starlette application for testing using the session manager.
Args:
@@ -245,9 +233,7 @@ def create_app(
return app
def run_server(
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
) -> None:
def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None:
"""Run the test server.
Args:
@@ -300,9 +286,7 @@ def json_server_port() -> int:
@pytest.fixture
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start a basic server."""
proc = multiprocessing.Process(
target=run_server, kwargs={"port": basic_server_port}, daemon=True
)
proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True)
proc.start()
# Wait for server to be running
@@ -778,9 +762,7 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server
@pytest.mark.anyio
async def test_streamablehttp_client_resource_read(initialized_client_session):
"""Test client resource read functionality."""
response = await initialized_client_session.read_resource(
uri=AnyUrl("foobar://test-resource")
)
response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource"))
assert len(response.contents) == 1
assert response.contents[0].uri == AnyUrl("foobar://test-resource")
assert response.contents[0].text == "Read test-resource"
@@ -805,17 +787,13 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
async def test_streamablehttp_client_error_handling(initialized_client_session):
"""Test error handling in client."""
with pytest.raises(McpError) as exc_info:
await initialized_client_session.read_resource(
uri=AnyUrl("unknown://test-error")
)
await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error"))
assert exc_info.value.error.code == 0
assert "Unknown resource: unknown://test-error" in exc_info.value.error.message
@pytest.mark.anyio
async def test_streamablehttp_client_session_persistence(
basic_server, basic_server_url
):
async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url):
"""Test that session ID persists across requests."""
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
@@ -843,9 +821,7 @@ async def test_streamablehttp_client_session_persistence(
@pytest.mark.anyio
async def test_streamablehttp_client_json_response(
json_response_server, json_server_url
):
async def test_streamablehttp_client_json_response(json_response_server, json_server_url):
"""Test client with JSON response mode."""
async with streamablehttp_client(f"{json_server_url}/mcp") as (
read_stream,
@@ -882,9 +858,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
# Define message handler to capture notifications
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, types.ServerNotification):
notifications_received.append(message)
@@ -894,9 +868,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
write_stream,
_,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
# Initialize the session - this triggers the GET stream setup
result = await session.initialize()
assert isinstance(result, InitializeResult)
@@ -914,15 +886,11 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
assert str(notif.root.params.uri) == "http://test_resource/"
resource_update_found = True
assert (
resource_update_found
), "ResourceUpdatedNotification not received via GET stream"
assert resource_update_found, "ResourceUpdatedNotification not received via GET stream"
@pytest.mark.anyio
async def test_streamablehttp_client_session_termination(
basic_server, basic_server_url
):
async def test_streamablehttp_client_session_termination(basic_server, basic_server_url):
"""Test client session termination functionality."""
captured_session_id = None
@@ -963,9 +931,7 @@ async def test_streamablehttp_client_session_termination(
@pytest.mark.anyio
async def test_streamablehttp_client_session_termination_204(
basic_server, basic_server_url, monkeypatch
):
async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch):
"""Test client session termination functionality with a 204 response.
This test patches the httpx client to return a 204 response for DELETEs.
@@ -1040,9 +1006,7 @@ async def test_streamablehttp_client_resumption(event_server):
tool_started = False
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, types.ServerNotification):
captured_notifications.append(message)
@@ -1062,9 +1026,7 @@ async def test_streamablehttp_client_resumption(event_server):
write_stream,
get_session_id,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
@@ -1082,9 +1044,7 @@ async def test_streamablehttp_client_resumption(event_server):
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="long_running_with_checkpoints", arguments={}
),
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
)
),
types.CallToolResult,
@@ -1114,9 +1074,7 @@ async def test_streamablehttp_client_resumption(event_server):
write_stream,
_,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
# Don't initialize - just use the existing session
# Resume the tool with the resumption token
@@ -1129,9 +1087,7 @@ async def test_streamablehttp_client_resumption(event_server):
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="long_running_with_checkpoints", arguments={}
),
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
)
),
types.CallToolResult,
@@ -1149,14 +1105,11 @@ async def test_streamablehttp_client_resumption(event_server):
# Should not have the first notification
# Check that "Tool started" notification isn't repeated when resuming
assert not any(
isinstance(n.root, types.LoggingMessageNotification)
and n.root.params.data == "Tool started"
isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
for n in captured_notifications
)
# there is no intersection between pre and post notifications
assert not any(
n in captured_notifications_pre for n in captured_notifications
)
assert not any(n in captured_notifications_pre for n in captured_notifications)
@pytest.mark.anyio
@@ -1175,11 +1128,7 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
nonlocal sampling_callback_invoked, captured_message_params
sampling_callback_invoked = True
captured_message_params = params
message_received = (
params.messages[0].content.text
if params.messages[0].content.type == "text"
else None
)
message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None
return types.CreateMessageResult(
role="assistant",
@@ -1212,19 +1161,13 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
# Verify the tool result contains the expected content
assert len(tool_result.content) == 1
assert tool_result.content[0].type == "text"
assert (
"Response from sampling: Received message from server"
in tool_result.content[0].text
)
assert "Response from sampling: Received message from server" in tool_result.content[0].text
# Verify sampling callback was invoked
assert sampling_callback_invoked
assert captured_message_params is not None
assert len(captured_message_params.messages) == 1
assert (
captured_message_params.messages[0].content.text
== "Server needs client sampling"
)
assert captured_message_params.messages[0].content.text == "Server needs client sampling"
# Context-aware server implementation for testing request context propagation
@@ -1325,9 +1268,7 @@ def run_context_aware_server(port: int):
@pytest.fixture
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
proc = multiprocessing.Process(
target=run_context_aware_server, args=(basic_server_port,), daemon=True
)
proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True)
proc.start()
# Wait for server to be running
@@ -1342,9 +1283,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context-aware server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts")
yield
@@ -1355,9 +1294,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
@pytest.mark.anyio
async def test_streamablehttp_request_context_propagation(
context_aware_server: None, basic_server_url: str
) -> None:
async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
"""Test that request context is properly propagated through StreamableHTTP."""
custom_headers = {
"Authorization": "Bearer test-token",
@@ -1365,9 +1302,11 @@ async def test_streamablehttp_request_context_propagation(
"X-Trace-Id": "trace-123",
}
async with streamablehttp_client(
f"{basic_server_url}/mcp", headers=custom_headers
) as (read_stream, write_stream, _):
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)
@@ -1388,9 +1327,7 @@ async def test_streamablehttp_request_context_propagation(
@pytest.mark.anyio
async def test_streamablehttp_request_context_isolation(
context_aware_server: None, basic_server_url: str
) -> None:
async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
"""Test that request contexts are isolated between StreamableHTTP clients."""
contexts = []
@@ -1402,16 +1339,12 @@ async def test_streamablehttp_request_context_isolation(
"Authorization": f"Bearer token-{i}",
}
async with streamablehttp_client(
f"{basic_server_url}/mcp", headers=headers
) as (read_stream, write_stream, _):
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
# Call the tool that echoes context
tool_result = await session.call_tool(
"echo_context", {"request_id": f"request-{i}"}
)
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)