Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -22,12 +22,8 @@ from mcp.shared.session import (
async def test_bidirectional_progress_notifications():
"""Test that both client and server can send progress notifications."""
# Create memory streams for client/server
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](5)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)
# Run a server session so we can send progress updates in tool
async def run_server():
@@ -134,9 +130,7 @@ async def test_bidirectional_progress_notifications():
# Client message handler to store progress notifications
async def handle_client_message(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
@@ -172,9 +166,7 @@ async def test_bidirectional_progress_notifications():
await client_session.list_tools()
# Call test_tool with progress token
await client_session.call_tool(
"test_tool", {"_meta": {"progressToken": client_progress_token}}
)
await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}})
# Send progress notifications from client to server
await client_session.send_progress_notification(
@@ -221,12 +213,8 @@ async def test_bidirectional_progress_notifications():
async def test_progress_context_manager():
"""Test client using progress context manager for sending progress notifications."""
# Create memory streams for client/server
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](5)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)
# Track progress updates
server_progress_updates = []
@@ -270,9 +258,7 @@ async def test_progress_context_manager():
# Client message handler
async def handle_client_message(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message

View File

@@ -90,9 +90,7 @@ async def test_request_cancellation():
ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="slow_tool", arguments={}
),
params=types.CallToolRequestParams(name="slow_tool", arguments={}),
)
),
types.CallToolResult,
@@ -103,9 +101,7 @@ async def test_request_cancellation():
assert "Request cancelled" in str(e)
ev_cancelled.set()
async with create_connected_server_and_client_session(
make_server()
) as client_session:
async with create_connected_server_and_client_session(make_server()) as client_session:
async with anyio.create_task_group() as tg:
tg.start_soon(make_request, client_session)

View File

@@ -60,11 +60,7 @@ class ServerTest(Server):
await anyio.sleep(2.0)
return f"Slow response from {uri.host}"
raise McpError(
error=ErrorData(
code=404, message="OOPS! no resource with that URI was found"
)
)
raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))
@self.list_tools()
async def handle_list_tools() -> list[Tool]:
@@ -88,12 +84,8 @@ def make_server_app() -> Starlette:
server = ServerTest()
async def handle_sse(request: Request) -> Response:
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()
)
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server.run(streams[0], streams[1], server.create_initialization_options())
return Response()
app = Starlette(
@@ -108,11 +100,7 @@ def make_server_app() -> Starlette:
def run_server(server_port: int) -> None:
app = make_server_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"starting server on {server_port}")
server.run()
@@ -124,9 +112,7 @@ def run_server(server_port: int) -> None:
@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process(
target=run_server, kwargs={"server_port": server_port}, daemon=True
)
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
print("starting process")
proc.start()
@@ -171,10 +157,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
async def connection_test() -> None:
async with http_client.stream("GET", "/sse") as response:
assert response.status_code == 200
assert (
response.headers["content-type"]
== "text/event-stream; charset=utf-8"
)
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
line_number = 0
async for line in response.aiter_lines():
@@ -206,9 +189,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
@pytest.fixture
async def initialized_sse_client_session(
server, server_url: str
) -> AsyncGenerator[ClientSession, None]:
async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
@@ -236,9 +217,7 @@ async def test_sse_client_exception_handling(
@pytest.mark.anyio
@pytest.mark.skip(
"this test highlights a possible bug in SSE read timeout exception handling"
)
@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling")
async def test_sse_client_timeout(
initialized_sse_client_session: ClientSession,
) -> None:
@@ -260,11 +239,7 @@ async def test_sse_client_timeout(
def run_mounted_server(server_port: int) -> None:
app = make_server_app()
main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
server = uvicorn.Server(
config=uvicorn.Config(
app=main_app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"starting server on {server_port}")
server.run()
@@ -276,9 +251,7 @@ def run_mounted_server(server_port: int) -> None:
@pytest.fixture()
def mounted_server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process(
target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True
)
proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True)
print("starting process")
proc.start()
@@ -308,9 +281,7 @@ def mounted_server(server_port: int) -> Generator[None, None, None]:
@pytest.mark.anyio
async def test_sse_client_basic_connection_mounted_app(
mounted_server: None, server_url: str
) -> None:
async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None:
async with sse_client(server_url + "/mounted_app/sse") as streams:
async with ClientSession(*streams) as session:
# Test initialization
@@ -372,12 +343,8 @@ def run_context_server(server_port: int) -> None:
context_server = RequestContextServer()
async def handle_sse(request: Request) -> Response:
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await context_server.run(
streams[0], streams[1], context_server.create_initialization_options()
)
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await context_server.run(streams[0], streams[1], context_server.create_initialization_options())
return Response()
app = Starlette(
@@ -387,11 +354,7 @@ def run_context_server(server_port: int) -> None:
]
)
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"starting context server on {server_port}")
server.run()
@@ -399,9 +362,7 @@ def run_context_server(server_port: int) -> None:
@pytest.fixture()
def context_server(server_port: int) -> Generator[None, None, None]:
"""Fixture that provides a server with request context capture"""
proc = multiprocessing.Process(
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
)
proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True)
print("starting context server process")
proc.start()
@@ -418,9 +379,7 @@ def context_server(server_port: int) -> Generator[None, None, None]:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Context server failed to start after {max_attempts} attempts")
yield
@@ -432,9 +391,7 @@ def context_server(server_port: int) -> Generator[None, None, None]:
@pytest.mark.anyio
async def test_request_context_propagation(
context_server: None, server_url: str
) -> None:
async def test_request_context_propagation(context_server: None, server_url: str) -> None:
"""Test that request context is properly propagated through SSE transport."""
# Test with custom headers
custom_headers = {
@@ -458,11 +415,7 @@ async def test_request_context_propagation(
# Parse the JSON response
assert len(tool_result.content) == 1
headers_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}")
# Verify headers were propagated
assert headers_data.get("authorization") == "Bearer test-token"
@@ -487,15 +440,11 @@ async def test_request_context_isolation(context_server: None, server_url: str)
await session.initialize()
# Call the tool that echoes context
tool_result = await session.call_tool(
"echo_context", {"request_id": f"request-{i}"}
)
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
assert len(tool_result.content) == 1
context_data = json.loads(
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
)
contexts.append(context_data)
@@ -514,8 +463,4 @@ def test_sse_message_id_coercion():
"""
json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
msg = types.JSONRPCMessage.model_validate_json(json_message)
assert msg == snapshot(
types.JSONRPCMessage(
root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)
)
)
assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)))

View File

@@ -72,9 +72,7 @@ class SimpleEventStore(EventStore):
self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
self._event_id_counter = 0
async def store_event(
self, stream_id: StreamId, message: types.JSONRPCMessage
) -> EventId:
async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId:
"""Store an event and return its ID."""
self._event_id_counter += 1
event_id = str(self._event_id_counter)
@@ -156,9 +154,7 @@ class ServerTest(Server):
# When the tool is called, send a notification to test GET stream
if name == "test_tool_with_standalone_notification":
await ctx.session.send_resource_updated(
uri=AnyUrl("http://test_resource")
)
await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource"))
return [TextContent(type="text", text=f"Called {name}")]
elif name == "long_running_with_checkpoints":
@@ -189,9 +185,7 @@ class ServerTest(Server):
messages=[
types.SamplingMessage(
role="user",
content=types.TextContent(
type="text", text="Server needs client sampling"
),
content=types.TextContent(type="text", text="Server needs client sampling"),
)
],
max_tokens=100,
@@ -199,11 +193,7 @@ class ServerTest(Server):
)
# Return the sampling result in the tool response
response = (
sampling_result.content.text
if sampling_result.content.type == "text"
else None
)
response = sampling_result.content.text if sampling_result.content.type == "text" else None
return [
TextContent(
type="text",
@@ -214,9 +204,7 @@ class ServerTest(Server):
return [TextContent(type="text", text=f"Called {name}")]
def create_app(
is_json_response_enabled=False, event_store: EventStore | None = None
) -> Starlette:
def create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette:
"""Create a Starlette application for testing using the session manager.
Args:
@@ -245,9 +233,7 @@ def create_app(
return app
def run_server(
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
) -> None:
def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None:
"""Run the test server.
Args:
@@ -300,9 +286,7 @@ def json_server_port() -> int:
@pytest.fixture
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start a basic server."""
proc = multiprocessing.Process(
target=run_server, kwargs={"port": basic_server_port}, daemon=True
)
proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True)
proc.start()
# Wait for server to be running
@@ -778,9 +762,7 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server
@pytest.mark.anyio
async def test_streamablehttp_client_resource_read(initialized_client_session):
"""Test client resource read functionality."""
response = await initialized_client_session.read_resource(
uri=AnyUrl("foobar://test-resource")
)
response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource"))
assert len(response.contents) == 1
assert response.contents[0].uri == AnyUrl("foobar://test-resource")
assert response.contents[0].text == "Read test-resource"
@@ -805,17 +787,13 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
async def test_streamablehttp_client_error_handling(initialized_client_session):
"""Test error handling in client."""
with pytest.raises(McpError) as exc_info:
await initialized_client_session.read_resource(
uri=AnyUrl("unknown://test-error")
)
await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error"))
assert exc_info.value.error.code == 0
assert "Unknown resource: unknown://test-error" in exc_info.value.error.message
@pytest.mark.anyio
async def test_streamablehttp_client_session_persistence(
basic_server, basic_server_url
):
async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url):
"""Test that session ID persists across requests."""
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
@@ -843,9 +821,7 @@ async def test_streamablehttp_client_session_persistence(
@pytest.mark.anyio
async def test_streamablehttp_client_json_response(
json_response_server, json_server_url
):
async def test_streamablehttp_client_json_response(json_response_server, json_server_url):
"""Test client with JSON response mode."""
async with streamablehttp_client(f"{json_server_url}/mcp") as (
read_stream,
@@ -882,9 +858,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
# Define message handler to capture notifications
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, types.ServerNotification):
notifications_received.append(message)
@@ -894,9 +868,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
write_stream,
_,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
# Initialize the session - this triggers the GET stream setup
result = await session.initialize()
assert isinstance(result, InitializeResult)
@@ -914,15 +886,11 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
assert str(notif.root.params.uri) == "http://test_resource/"
resource_update_found = True
assert (
resource_update_found
), "ResourceUpdatedNotification not received via GET stream"
assert resource_update_found, "ResourceUpdatedNotification not received via GET stream"
@pytest.mark.anyio
async def test_streamablehttp_client_session_termination(
basic_server, basic_server_url
):
async def test_streamablehttp_client_session_termination(basic_server, basic_server_url):
"""Test client session termination functionality."""
captured_session_id = None
@@ -963,9 +931,7 @@ async def test_streamablehttp_client_session_termination(
@pytest.mark.anyio
async def test_streamablehttp_client_session_termination_204(
basic_server, basic_server_url, monkeypatch
):
async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch):
"""Test client session termination functionality with a 204 response.
This test patches the httpx client to return a 204 response for DELETEs.
@@ -1040,9 +1006,7 @@ async def test_streamablehttp_client_resumption(event_server):
tool_started = False
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, types.ServerNotification):
captured_notifications.append(message)
@@ -1062,9 +1026,7 @@ async def test_streamablehttp_client_resumption(event_server):
write_stream,
get_session_id,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
@@ -1082,9 +1044,7 @@ async def test_streamablehttp_client_resumption(event_server):
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="long_running_with_checkpoints", arguments={}
),
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
)
),
types.CallToolResult,
@@ -1114,9 +1074,7 @@ async def test_streamablehttp_client_resumption(event_server):
write_stream,
_,
):
async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
# Don't initialize - just use the existing session
# Resume the tool with the resumption token
@@ -1129,9 +1087,7 @@ async def test_streamablehttp_client_resumption(event_server):
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="long_running_with_checkpoints", arguments={}
),
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
)
),
types.CallToolResult,
@@ -1149,14 +1105,11 @@ async def test_streamablehttp_client_resumption(event_server):
# Should not have the first notification
# Check that "Tool started" notification isn't repeated when resuming
assert not any(
isinstance(n.root, types.LoggingMessageNotification)
and n.root.params.data == "Tool started"
isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
for n in captured_notifications
)
# there is no intersection between pre and post notifications
assert not any(
n in captured_notifications_pre for n in captured_notifications
)
assert not any(n in captured_notifications_pre for n in captured_notifications)
@pytest.mark.anyio
@@ -1175,11 +1128,7 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
nonlocal sampling_callback_invoked, captured_message_params
sampling_callback_invoked = True
captured_message_params = params
message_received = (
params.messages[0].content.text
if params.messages[0].content.type == "text"
else None
)
message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None
return types.CreateMessageResult(
role="assistant",
@@ -1212,19 +1161,13 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
# Verify the tool result contains the expected content
assert len(tool_result.content) == 1
assert tool_result.content[0].type == "text"
assert (
"Response from sampling: Received message from server"
in tool_result.content[0].text
)
assert "Response from sampling: Received message from server" in tool_result.content[0].text
# Verify sampling callback was invoked
assert sampling_callback_invoked
assert captured_message_params is not None
assert len(captured_message_params.messages) == 1
assert (
captured_message_params.messages[0].content.text
== "Server needs client sampling"
)
assert captured_message_params.messages[0].content.text == "Server needs client sampling"
# Context-aware server implementation for testing request context propagation
@@ -1325,9 +1268,7 @@ def run_context_aware_server(port: int):
@pytest.fixture
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
proc = multiprocessing.Process(
target=run_context_aware_server, args=(basic_server_port,), daemon=True
)
proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True)
proc.start()
# Wait for server to be running
@@ -1342,9 +1283,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context-aware server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts")
yield
@@ -1355,9 +1294,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
@pytest.mark.anyio
async def test_streamablehttp_request_context_propagation(
context_aware_server: None, basic_server_url: str
) -> None:
async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
"""Test that request context is properly propagated through StreamableHTTP."""
custom_headers = {
"Authorization": "Bearer test-token",
@@ -1365,9 +1302,11 @@ async def test_streamablehttp_request_context_propagation(
"X-Trace-Id": "trace-123",
}
async with streamablehttp_client(
f"{basic_server_url}/mcp", headers=custom_headers
) as (read_stream, write_stream, _):
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)
@@ -1388,9 +1327,7 @@ async def test_streamablehttp_request_context_propagation(
@pytest.mark.anyio
async def test_streamablehttp_request_context_isolation(
context_aware_server: None, basic_server_url: str
) -> None:
async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
"""Test that request contexts are isolated between StreamableHTTP clients."""
contexts = []
@@ -1402,16 +1339,12 @@ async def test_streamablehttp_request_context_isolation(
"Authorization": f"Bearer token-{i}",
}
async with streamablehttp_client(
f"{basic_server_url}/mcp", headers=headers
) as (read_stream, write_stream, _):
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
# Call the tool that echoes context
tool_result = await session.call_tool(
"echo_context", {"request_id": f"request-{i}"}
)
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)

View File

@@ -54,11 +54,7 @@ class ServerTest(Server):
await anyio.sleep(2.0)
return f"Slow response from {uri.host}"
raise McpError(
error=ErrorData(
code=404, message="OOPS! no resource with that URI was found"
)
)
raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))
@self.list_tools()
async def handle_list_tools() -> list[Tool]:
@@ -81,12 +77,8 @@ def make_server_app() -> Starlette:
server = ServerTest()
async def handle_ws(websocket):
async with websocket_server(
websocket.scope, websocket.receive, websocket.send
) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()
)
async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams:
await server.run(streams[0], streams[1], server.create_initialization_options())
app = Starlette(
routes=[
@@ -99,11 +91,7 @@ def make_server_app() -> Starlette:
def run_server(server_port: int) -> None:
app = make_server_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"starting server on {server_port}")
server.run()
@@ -115,9 +103,7 @@ def run_server(server_port: int) -> None:
@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process(
target=run_server, kwargs={"server_port": server_port}, daemon=True
)
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
print("starting process")
proc.start()
@@ -147,9 +133,7 @@ def server(server_port: int) -> Generator[None, None, None]:
@pytest.fixture()
async def initialized_ws_client_session(
server, server_url: str
) -> AsyncGenerator[ClientSession, None]:
async def initialized_ws_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
"""Create and initialize a WebSocket client session"""
async with websocket_client(server_url + "/ws") as streams:
async with ClientSession(*streams) as session:
@@ -186,9 +170,7 @@ async def test_ws_client_happy_request_and_response(
initialized_ws_client_session: ClientSession,
) -> None:
"""Test a successful request and response via WebSocket"""
result = await initialized_ws_client_session.read_resource(
AnyUrl("foobar://example")
)
result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example"))
assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list)
assert len(result.contents) > 0
@@ -218,9 +200,7 @@ async def test_ws_client_timeout(
# Now test that we can still use the session after a timeout
with anyio.fail_after(5): # Longer timeout to allow completion
result = await initialized_ws_client_session.read_resource(
AnyUrl("foobar://example")
)
result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example"))
assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list)
assert len(result.contents) > 0