Streamable HTTP - improve usability, fast mcp and auth (#641)

This commit is contained in:
ihrpr
2025-05-08 20:43:25 +01:00
committed by GitHub
parent 280bab36f4
commit e4e119b324
7 changed files with 750 additions and 229 deletions

View File

@@ -15,6 +15,7 @@ import uvicorn
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.server.fastmcp import FastMCP
from mcp.types import InitializeResult, TextContent
@@ -33,6 +34,34 @@ def server_url(server_port: int) -> str:
return f"http://127.0.0.1:{server_port}"
@pytest.fixture
def http_server_port() -> int:
"""Get a free port for testing the StreamableHTTP server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
@pytest.fixture
def http_server_url(http_server_port: int) -> str:
"""Get the StreamableHTTP server URL for testing."""
return f"http://127.0.0.1:{http_server_port}"
@pytest.fixture
def stateless_http_server_port() -> int:
"""Get a free port for testing the stateless StreamableHTTP server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
@pytest.fixture
def stateless_http_server_url(stateless_http_server_port: int) -> str:
"""Get the stateless StreamableHTTP server URL for testing."""
return f"http://127.0.0.1:{stateless_http_server_port}"
# Create a function to make the FastMCP server app
def make_fastmcp_app():
"""Create a FastMCP server without auth settings."""
@@ -51,6 +80,40 @@ def make_fastmcp_app():
return mcp, app
def make_fastmcp_streamable_http_app():
"""Create a FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette
mcp = FastMCP(name="NoAuthServer")
# Add a simple tool
@mcp.tool(description="A simple echo tool")
def echo(message: str) -> str:
return f"Echo: {message}"
# Create the StreamableHTTP app
app: Starlette = mcp.streamable_http_app()
return mcp, app
def make_fastmcp_stateless_http_app():
"""Create a FastMCP server with stateless StreamableHTTP transport."""
from starlette.applications import Starlette
mcp = FastMCP(name="StatelessServer", stateless_http=True)
# Add a simple tool
@mcp.tool(description="A simple echo tool")
def echo(message: str) -> str:
return f"Echo: {message}"
# Create the StreamableHTTP app
app: Starlette = mcp.streamable_http_app()
return mcp, app
def run_server(server_port: int) -> None:
"""Run the server."""
_, app = make_fastmcp_app()
@@ -63,6 +126,30 @@ def run_server(server_port: int) -> None:
server.run()
def run_streamable_http_server(server_port: int) -> None:
"""Run the StreamableHTTP server."""
_, app = make_fastmcp_streamable_http_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting StreamableHTTP server on port {server_port}")
server.run()
def run_stateless_http_server(server_port: int) -> None:
"""Run the stateless StreamableHTTP server."""
_, app = make_fastmcp_stateless_http_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting stateless StreamableHTTP server on port {server_port}")
server.run()
@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
"""Start the server in a separate process and clean up after the test."""
@@ -94,6 +181,80 @@ def server(server_port: int) -> Generator[None, None, None]:
print("Server process failed to terminate")
@pytest.fixture()
def streamable_http_server(http_server_port: int) -> Generator[None, None, None]:
"""Start the StreamableHTTP server in a separate process."""
proc = multiprocessing.Process(
target=run_streamable_http_server, args=(http_server_port,), daemon=True
)
print("Starting StreamableHTTP server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("Waiting for StreamableHTTP server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", http_server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"StreamableHTTP server failed to start after {max_attempts} attempts"
)
yield
print("Killing StreamableHTTP server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("StreamableHTTP server process failed to terminate")
@pytest.fixture()
def stateless_http_server(
stateless_http_server_port: int,
) -> Generator[None, None, None]:
"""Start the stateless StreamableHTTP server in a separate process."""
proc = multiprocessing.Process(
target=run_stateless_http_server,
args=(stateless_http_server_port,),
daemon=True,
)
print("Starting stateless StreamableHTTP server process")
proc.start()
# Wait for server to be running
max_attempts = 20
attempt = 0
print("Waiting for stateless StreamableHTTP server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", stateless_http_server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Stateless server failed to start after {max_attempts} attempts"
)
yield
print("Killing stateless StreamableHTTP server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("Stateless StreamableHTTP server process failed to terminate")
@pytest.mark.anyio
async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
"""Test that FastMCP works when auth settings are not provided."""
@@ -110,3 +271,55 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == "Echo: hello"
@pytest.mark.anyio
async def test_fastmcp_streamable_http(
streamable_http_server: None, http_server_url: str
) -> None:
"""Test that FastMCP works with StreamableHTTP transport."""
# Connect to the server using StreamableHTTP
async with streamablehttp_client(http_server_url + "/mcp") as (
read_stream,
write_stream,
_,
):
# Create a session using the client streams
async with ClientSession(read_stream, write_stream) as session:
# Test initialization
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == "NoAuthServer"
# Test that we can call tools without authentication
tool_result = await session.call_tool("echo", {"message": "hello"})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == "Echo: hello"
@pytest.mark.anyio
async def test_fastmcp_stateless_streamable_http(
stateless_http_server: None, stateless_http_server_url: str
) -> None:
"""Test that FastMCP works with stateless StreamableHTTP transport."""
# Connect to the server using StreamableHTTP
async with streamablehttp_client(stateless_http_server_url + "/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == "StatelessServer"
tool_result = await session.call_tool("echo", {"message": "hello"})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == "Echo: hello"
for i in range(3):
tool_result = await session.call_tool("echo", {"message": f"test_{i}"})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == f"Echo: test_{i}"

View File

@@ -0,0 +1,81 @@
"""Tests for StreamableHTTPSessionManager."""
import anyio
import pytest
from mcp.server.lowlevel import Server
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
@pytest.mark.anyio
async def test_run_can_only_be_called_once():
"""Test that run() can only be called once per instance."""
app = Server("test-server")
manager = StreamableHTTPSessionManager(app=app)
# First call should succeed
async with manager.run():
pass
# Second call should raise RuntimeError
with pytest.raises(RuntimeError) as excinfo:
async with manager.run():
pass
assert (
"StreamableHTTPSessionManager .run() can only be called once per instance"
in str(excinfo.value)
)
@pytest.mark.anyio
async def test_run_prevents_concurrent_calls():
"""Test that concurrent calls to run() are prevented."""
app = Server("test-server")
manager = StreamableHTTPSessionManager(app=app)
errors = []
async def try_run():
try:
async with manager.run():
# Simulate some work
await anyio.sleep(0.1)
except RuntimeError as e:
errors.append(e)
# Try to run concurrently
async with anyio.create_task_group() as tg:
tg.start_soon(try_run)
tg.start_soon(try_run)
# One should succeed, one should fail
assert len(errors) == 1
assert (
"StreamableHTTPSessionManager .run() can only be called once per instance"
in str(errors[0])
)
@pytest.mark.anyio
async def test_handle_request_without_run_raises_error():
"""Test that handle_request raises error if run() hasn't been called."""
app = Server("test-server")
manager = StreamableHTTPSessionManager(app=app)
# Mock ASGI parameters
scope = {"type": "http", "method": "POST", "path": "/test"}
async def receive():
return {"type": "http.request", "body": b""}
async def send(message):
pass
# Should raise error because run() hasn't been called
with pytest.raises(RuntimeError) as excinfo:
await manager.handle_request(scope, receive, send)
assert "Task group is not initialized. Make sure to use run()." in str(
excinfo.value
)

View File

@@ -4,13 +4,10 @@ Tests for the StreamableHTTP server and client transport.
Contains tests for both server and client sides of the StreamableHTTP transport.
"""
import contextlib
import multiprocessing
import socket
import time
from collections.abc import Generator
from http import HTTPStatus
from uuid import uuid4
import anyio
import httpx
@@ -19,8 +16,6 @@ import requests
import uvicorn
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
import mcp.types as types
@@ -37,6 +32,7 @@ from mcp.server.streamable_http import (
StreamableHTTPServerTransport,
StreamId,
)
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.exceptions import McpError
from mcp.shared.message import (
ClientMessageMetadata,
@@ -184,7 +180,7 @@ class ServerTest(Server):
def create_app(
is_json_response_enabled=False, event_store: EventStore | None = None
) -> Starlette:
"""Create a Starlette application for testing that matches the example server.
"""Create a Starlette application for testing using the session manager.
Args:
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
@@ -193,85 +189,20 @@ def create_app(
# Create server instance
server = ServerTest()
server_instances = {}
# Lock to prevent race conditions when creating new sessions
session_creation_lock = anyio.Lock()
task_group = None
# Create the session manager
session_manager = StreamableHTTPSessionManager(
app=server,
event_store=event_store,
json_response=is_json_response_enabled,
)
@contextlib.asynccontextmanager
async def lifespan(app):
"""Application lifespan context manager for managing task group."""
nonlocal task_group
async with anyio.create_task_group() as tg:
task_group = tg
try:
yield
finally:
if task_group:
tg.cancel_scope.cancel()
task_group = None
async def handle_streamable_http(scope, receive, send):
request = Request(scope, receive)
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
# Use existing transport if session ID matches
if (
request_mcp_session_id is not None
and request_mcp_session_id in server_instances
):
transport = server_instances[request_mcp_session_id]
await transport.handle_request(scope, receive, send)
elif request_mcp_session_id is None:
async with session_creation_lock:
new_session_id = uuid4().hex
http_transport = StreamableHTTPServerTransport(
mcp_session_id=new_session_id,
is_json_response_enabled=is_json_response_enabled,
event_store=event_store,
)
async def run_server(task_status=None):
async with http_transport.connect() as streams:
read_stream, write_stream = streams
if task_status:
task_status.started()
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
)
if task_group is None:
response = Response(
"Internal Server Error: Task group is not initialized",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
await response(scope, receive, send)
return
# Store the instance before starting the task to prevent races
server_instances[http_transport.mcp_session_id] = http_transport
await task_group.start(run_server)
await http_transport.handle_request(scope, receive, send)
else:
response = Response(
"Bad Request: No valid session ID provided",
status_code=HTTPStatus.BAD_REQUEST,
)
await response(scope, receive, send)
# Create an ASGI application
# Create an ASGI application that uses the session manager
app = Starlette(
debug=True,
routes=[
Mount("/mcp", app=handle_streamable_http),
Mount("/mcp", app=session_manager.handle_request),
],
lifespan=lifespan,
lifespan=lambda app: session_manager.run(),
)
return app