mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
StreamableHttp client transport (#573)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Tests for the StreamableHTTP server transport validation.
|
||||
Tests for the StreamableHTTP server and client transport.
|
||||
|
||||
This file contains tests for request validation in the StreamableHTTP transport.
|
||||
Contains tests for both server and client sides of the StreamableHTTP transport.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
@@ -13,6 +13,7 @@ from http import HTTPStatus
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
import uvicorn
|
||||
@@ -22,18 +23,16 @@ from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.server import Server
|
||||
from mcp.server.streamableHttp import (
|
||||
from mcp.server.streamable_http import (
|
||||
MCP_SESSION_ID_HEADER,
|
||||
SESSION_ID_PATTERN,
|
||||
StreamableHTTPServerTransport,
|
||||
)
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import (
|
||||
ErrorData,
|
||||
TextContent,
|
||||
Tool,
|
||||
)
|
||||
from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool
|
||||
|
||||
# Test constants
|
||||
SERVER_NAME = "test_streamable_http_server"
|
||||
@@ -64,11 +63,7 @@ class ServerTest(Server):
|
||||
await anyio.sleep(2.0)
|
||||
return f"Slow response from {uri.host}"
|
||||
|
||||
raise McpError(
|
||||
error=ErrorData(
|
||||
code=404, message="OOPS! no resource with that URI was found"
|
||||
)
|
||||
)
|
||||
raise ValueError(f"Unknown resource: {uri}")
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
@@ -77,11 +72,23 @@ class ServerTest(Server):
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
),
|
||||
Tool(
|
||||
name="test_tool_with_standalone_notification",
|
||||
description="A test tool that sends a notification",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
]
|
||||
|
||||
@self.call_tool()
|
||||
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
||||
# When the tool is called, send a notification to test GET stream
|
||||
if name == "test_tool_with_standalone_notification":
|
||||
ctx = self.request_context
|
||||
await ctx.session.send_resource_updated(
|
||||
uri=AnyUrl("http://test_resource")
|
||||
)
|
||||
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
|
||||
@@ -630,3 +637,219 @@ def test_get_validation(basic_server, basic_server_url):
|
||||
)
|
||||
assert response.status_code == 406
|
||||
assert "Not Acceptable" in response.text
|
||||
|
||||
|
||||
# Client-specific fixtures
|
||||
@pytest.fixture
|
||||
async def http_client(basic_server, basic_server_url):
|
||||
"""Create test client matching the SSE test pattern."""
|
||||
async with httpx.AsyncClient(base_url=basic_server_url) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def initialized_client_session(basic_server, basic_server_url):
|
||||
"""Create initialized StreamableHTTP client session."""
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_basic_connection(basic_server, basic_server_url):
|
||||
"""Test basic client connection with initialization."""
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
) as session:
|
||||
# Test initialization
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.serverInfo.name == SERVER_NAME
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_resource_read(initialized_client_session):
|
||||
"""Test client resource read functionality."""
|
||||
response = await initialized_client_session.read_resource(
|
||||
uri=AnyUrl("foobar://test-resource")
|
||||
)
|
||||
assert len(response.contents) == 1
|
||||
assert response.contents[0].uri == AnyUrl("foobar://test-resource")
|
||||
assert response.contents[0].text == "Read test-resource"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_tool_invocation(initialized_client_session):
|
||||
"""Test client tool invocation."""
|
||||
# First list tools
|
||||
tools = await initialized_client_session.list_tools()
|
||||
assert len(tools.tools) == 2
|
||||
assert tools.tools[0].name == "test_tool"
|
||||
|
||||
# Call the tool
|
||||
result = await initialized_client_session.call_tool("test_tool", {})
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0].type == "text"
|
||||
assert result.content[0].text == "Called test_tool"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_error_handling(initialized_client_session):
|
||||
"""Test error handling in client."""
|
||||
with pytest.raises(McpError) as exc_info:
|
||||
await initialized_client_session.read_resource(
|
||||
uri=AnyUrl("unknown://test-error")
|
||||
)
|
||||
assert exc_info.value.error.code == 0
|
||||
assert "Unknown resource: unknown://test-error" in exc_info.value.error.message
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_session_persistence(
|
||||
basic_server, basic_server_url
|
||||
):
|
||||
"""Test that session ID persists across requests."""
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
|
||||
# Make multiple requests to verify session persistence
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 2
|
||||
|
||||
# Read a resource
|
||||
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
|
||||
assert isinstance(resource.contents[0], TextResourceContents) is True
|
||||
content = resource.contents[0]
|
||||
assert isinstance(content, TextResourceContents)
|
||||
assert content.text == "Read test-persist"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_json_response(
|
||||
json_response_server, json_server_url
|
||||
):
|
||||
"""Test client with JSON response mode."""
|
||||
async with streamablehttp_client(f"{json_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.serverInfo.name == SERVER_NAME
|
||||
|
||||
# Check tool listing
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 2
|
||||
|
||||
# Call a tool and verify JSON response handling
|
||||
result = await session.call_tool("test_tool", {})
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0].type == "text"
|
||||
assert result.content[0].text == "Called test_tool"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
|
||||
"""Test GET stream functionality for server-initiated messages."""
|
||||
import mcp.types as types
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
notifications_received = []
|
||||
|
||||
# Define message handler to capture notifications
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, types.ServerNotification):
|
||||
notifications_received.append(message)
|
||||
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
# Initialize the session - this triggers the GET stream setup
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
|
||||
# Call the special tool that sends a notification
|
||||
await session.call_tool("test_tool_with_standalone_notification", {})
|
||||
|
||||
# Verify we received the notification
|
||||
assert len(notifications_received) > 0
|
||||
|
||||
# Verify the notification is a ResourceUpdatedNotification
|
||||
resource_update_found = False
|
||||
for notif in notifications_received:
|
||||
if isinstance(notif.root, types.ResourceUpdatedNotification):
|
||||
assert str(notif.root.params.uri) == "http://test_resource/"
|
||||
resource_update_found = True
|
||||
|
||||
assert (
|
||||
resource_update_found
|
||||
), "ResourceUpdatedNotification not received via GET stream"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_streamablehttp_client_session_termination(
|
||||
basic_server, basic_server_url
|
||||
):
|
||||
"""Test client session termination functionality."""
|
||||
|
||||
# Create the streamablehttp_client with a custom httpx client to capture headers
|
||||
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
terminate_session,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# Initialize the session
|
||||
result = await session.initialize()
|
||||
assert isinstance(result, InitializeResult)
|
||||
|
||||
# Make a request to confirm session is working
|
||||
tools = await session.list_tools()
|
||||
assert len(tools.tools) == 2
|
||||
|
||||
# After exiting ClientSession context, explicitly terminate the session
|
||||
await terminate_session()
|
||||
with pytest.raises(
|
||||
McpError,
|
||||
match="Session terminated",
|
||||
):
|
||||
await session.list_tools()
|
||||
Reference in New Issue
Block a user