mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -72,9 +72,7 @@ class SimpleEventStore(EventStore):
|
||||
self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
|
||||
self._event_id_counter = 0
|
||||
|
||||
async def store_event(
|
||||
self, stream_id: StreamId, message: types.JSONRPCMessage
|
||||
) -> EventId:
|
||||
async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId:
|
||||
"""Store an event and return its ID."""
|
||||
self._event_id_counter += 1
|
||||
event_id = str(self._event_id_counter)
|
||||
@@ -156,9 +154,7 @@ class ServerTest(Server):
|
||||
|
||||
# When the tool is called, send a notification to test GET stream
|
||||
if name == "test_tool_with_standalone_notification":
|
||||
await ctx.session.send_resource_updated(
|
||||
uri=AnyUrl("http://test_resource")
|
||||
)
|
||||
await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource"))
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
elif name == "long_running_with_checkpoints":
|
||||
@@ -189,9 +185,7 @@ class ServerTest(Server):
|
||||
messages=[
|
||||
types.SamplingMessage(
|
||||
role="user",
|
||||
content=types.TextContent(
|
||||
type="text", text="Server needs client sampling"
|
||||
),
|
||||
content=types.TextContent(type="text", text="Server needs client sampling"),
|
||||
)
|
||||
],
|
||||
max_tokens=100,
|
||||
@@ -199,11 +193,7 @@ class ServerTest(Server):
|
||||
)
|
||||
|
||||
# Return the sampling result in the tool response
|
||||
response = (
|
||||
sampling_result.content.text
|
||||
if sampling_result.content.type == "text"
|
||||
else None
|
||||
)
|
||||
response = sampling_result.content.text if sampling_result.content.type == "text" else None
|
||||
return [
|
||||
TextContent(
|
||||
type="text",
|
||||
@@ -214,9 +204,7 @@ class ServerTest(Server):
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
|
||||
def create_app(
|
||||
is_json_response_enabled=False, event_store: EventStore | None = None
|
||||
) -> Starlette:
|
||||
def create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette:
|
||||
"""Create a Starlette application for testing using the session manager.
|
||||
|
||||
Args:
|
||||
@@ -245,9 +233,7 @@ def create_app(
|
||||
return app
|
||||
|
||||
|
||||
def run_server(
|
||||
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
|
||||
) -> None:
|
||||
def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None:
|
||||
"""Run the test server.
|
||||
|
||||
Args:
|
||||
@@ -300,9 +286,7 @@ def json_server_port() -> int:
|
||||
@pytest.fixture
|
||||
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start a basic server."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_server, kwargs={"port": basic_server_port}, daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True)
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
@@ -778,9 +762,7 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_resource_read(initialized_client_session):
|
||||
"""Test client resource read functionality."""
|
||||
response = await initialized_client_session.read_resource(
|
||||
uri=AnyUrl("foobar://test-resource")
|
||||
)
|
||||
response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource"))
|
||||
assert len(response.contents) == 1
|
||||
assert response.contents[0].uri == AnyUrl("foobar://test-resource")
|
||||
assert response.contents[0].text == "Read test-resource"
|
||||
@@ -805,17 +787,13 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
|
||||
async def test_streamablehttp_client_error_handling(initialized_client_session):
|
||||
"""Test error handling in client."""
|
||||
with pytest.raises(McpError) as exc_info:
|
||||
await initialized_client_session.read_resource(
|
||||
uri=AnyUrl("unknown://test-error")
|
||||
)
|
||||
await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error"))
|
||||
assert exc_info.value.error.code == 0
|
||||
assert "Unknown resource: unknown://test-error" in exc_info.value.error.message
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_session_persistence(
|
||||
basic_server, basic_server_url
|
||||
):
|
||||
async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url):
|
||||
"""Test that session ID persists across requests."""
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
@@ -843,9 +821,7 @@ async def test_streamablehttp_client_session_persistence(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_json_response(
|
||||
json_response_server, json_server_url
|
||||
):
|
||||
async def test_streamablehttp_client_json_response(json_response_server, json_server_url):
|
||||
"""Test client with JSON response mode."""
|
||||
async with streamablehttp_client(f"{json_server_url}/mcp") as (
|
||||
read_stream,
|
||||
@@ -882,9 +858,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
|
||||
|
||||
# Define message handler to capture notifications
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, types.ServerNotification):
|
||||
notifications_received.append(message)
|
||||
@@ -894,9 +868,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
|
||||
# Initialize the session - this triggers the GET stream setup
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
@@ -914,15 +886,11 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
|
||||
assert str(notif.root.params.uri) == "http://test_resource/"
|
||||
resource_update_found = True
|
||||
|
||||
assert (
|
||||
resource_update_found
|
||||
), "ResourceUpdatedNotification not received via GET stream"
|
||||
assert resource_update_found, "ResourceUpdatedNotification not received via GET stream"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_session_termination(
|
||||
basic_server, basic_server_url
|
||||
):
|
||||
async def test_streamablehttp_client_session_termination(basic_server, basic_server_url):
|
||||
"""Test client session termination functionality."""
|
||||
|
||||
captured_session_id = None
|
||||
@@ -963,9 +931,7 @@ async def test_streamablehttp_client_session_termination(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_session_termination_204(
|
||||
basic_server, basic_server_url, monkeypatch
|
||||
):
|
||||
async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch):
|
||||
"""Test client session termination functionality with a 204 response.
|
||||
|
||||
This test patches the httpx client to return a 204 response for DELETEs.
|
||||
@@ -1040,9 +1006,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
tool_started = False
|
||||
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, types.ServerNotification):
|
||||
captured_notifications.append(message)
|
||||
@@ -1062,9 +1026,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
write_stream,
|
||||
get_session_id,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
@@ -1082,9 +1044,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(
|
||||
name="long_running_with_checkpoints", arguments={}
|
||||
),
|
||||
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
@@ -1114,9 +1074,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
|
||||
# Don't initialize - just use the existing session
|
||||
|
||||
# Resume the tool with the resumption token
|
||||
@@ -1129,9 +1087,7 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(
|
||||
name="long_running_with_checkpoints", arguments={}
|
||||
),
|
||||
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
@@ -1149,14 +1105,11 @@ async def test_streamablehttp_client_resumption(event_server):
|
||||
# Should not have the first notification
|
||||
# Check that "Tool started" notification isn't repeated when resuming
|
||||
assert not any(
|
||||
isinstance(n.root, types.LoggingMessageNotification)
|
||||
and n.root.params.data == "Tool started"
|
||||
isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
|
||||
for n in captured_notifications
|
||||
)
|
||||
# there is no intersection between pre and post notifications
|
||||
assert not any(
|
||||
n in captured_notifications_pre for n in captured_notifications
|
||||
)
|
||||
assert not any(n in captured_notifications_pre for n in captured_notifications)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -1175,11 +1128,7 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
|
||||
nonlocal sampling_callback_invoked, captured_message_params
|
||||
sampling_callback_invoked = True
|
||||
captured_message_params = params
|
||||
message_received = (
|
||||
params.messages[0].content.text
|
||||
if params.messages[0].content.type == "text"
|
||||
else None
|
||||
)
|
||||
message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None
|
||||
|
||||
return types.CreateMessageResult(
|
||||
role="assistant",
|
||||
@@ -1212,19 +1161,13 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
|
||||
# Verify the tool result contains the expected content
|
||||
assert len(tool_result.content) == 1
|
||||
assert tool_result.content[0].type == "text"
|
||||
assert (
|
||||
"Response from sampling: Received message from server"
|
||||
in tool_result.content[0].text
|
||||
)
|
||||
assert "Response from sampling: Received message from server" in tool_result.content[0].text
|
||||
|
||||
# Verify sampling callback was invoked
|
||||
assert sampling_callback_invoked
|
||||
assert captured_message_params is not None
|
||||
assert len(captured_message_params.messages) == 1
|
||||
assert (
|
||||
captured_message_params.messages[0].content.text
|
||||
== "Server needs client sampling"
|
||||
)
|
||||
assert captured_message_params.messages[0].content.text == "Server needs client sampling"
|
||||
|
||||
|
||||
# Context-aware server implementation for testing request context propagation
|
||||
@@ -1325,9 +1268,7 @@ def run_context_aware_server(port: int):
|
||||
@pytest.fixture
|
||||
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start the context-aware server in a separate process."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_context_aware_server, args=(basic_server_port,), daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True)
|
||||
proc.start()
|
||||
|
||||
# Wait for server to be running
|
||||
@@ -1342,9 +1283,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Context-aware server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
@@ -1355,9 +1294,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_request_context_propagation(
|
||||
context_aware_server: None, basic_server_url: str
|
||||
) -> None:
|
||||
async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
|
||||
"""Test that request context is properly propagated through StreamableHTTP."""
|
||||
custom_headers = {
|
||||
"Authorization": "Bearer test-token",
|
||||
@@ -1365,9 +1302,11 @@ async def test_streamablehttp_request_context_propagation(
|
||||
"X-Trace-Id": "trace-123",
|
||||
}
|
||||
|
||||
async with streamablehttp_client(
|
||||
f"{basic_server_url}/mcp", headers=custom_headers
|
||||
) as (read_stream, write_stream, _):
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
@@ -1388,9 +1327,7 @@ async def test_streamablehttp_request_context_propagation(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_request_context_isolation(
|
||||
context_aware_server: None, basic_server_url: str
|
||||
) -> None:
|
||||
async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
|
||||
"""Test that request contexts are isolated between StreamableHTTP clients."""
|
||||
contexts = []
|
||||
|
||||
@@ -1402,16 +1339,12 @@ async def test_streamablehttp_request_context_isolation(
|
||||
"Authorization": f"Bearer token-{i}",
|
||||
}
|
||||
|
||||
async with streamablehttp_client(
|
||||
f"{basic_server_url}/mcp", headers=headers
|
||||
) as (read_stream, write_stream, _):
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
|
||||
# Call the tool that echoes context
|
||||
tool_result = await session.call_tool(
|
||||
"echo_context", {"request_id": f"request-{i}"}
|
||||
)
|
||||
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
|
||||
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
|
||||
Reference in New Issue
Block a user