Files
mcp-python-sdk/tests/server/message_dispatch/test_redis_integration.py

261 lines
8.8 KiB
Python

"""
Integration tests for Redis message dispatch functionality.
These tests validate Redis message dispatch by making actual HTTP calls and testing
that messages flow correctly through the Redis backend.
This version runs the server in a task instead of a separate process to allow
access to the fakeredis instance for verification of Redis keys.
"""
import asyncio
import socket
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import anyio
import pytest
import uvicorn
from sse_starlette.sse import AppStatus
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.server import Server
from mcp.server.message_queue.redis import RedisMessageDispatch
from mcp.server.sse import SseServerTransport
from mcp.types import TextContent, Tool
# Name the test server registers with; the client_session fixture checks that
# the initialize handshake reports exactly this value.
SERVER_NAME: str = "test_server_for_redis_integration_v3"
@pytest.fixture
def server_port() -> int:
    """Pick a free TCP port on the loopback interface.

    Binding to port 0 lets the OS choose an unused port. The probe socket is
    closed before the number is returned, so the port is only likely (not
    guaranteed) to still be free when the server binds it.
    """
    probe = socket.socket()
    with probe:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
    return port
@pytest.fixture
def server_url(server_port: int) -> str:
    """Base HTTP URL of the test server listening on ``server_port``."""
    url = f"http://127.0.0.1:{server_port}"
    return url
class RedisTestServer(Server):
    """MCP server exposing two trivial tools for the Redis dispatch tests.

    * ``test_tool`` takes no arguments.
    * ``echo_message`` answers ``"Echo: <message>"``.

    Every tool name other than ``echo_message`` is answered with
    ``"Called <name>"``.
    """

    def __init__(self):
        super().__init__(SERVER_NAME)

        @self.list_tools()
        async def handle_list_tools() -> list[Tool]:
            empty_schema = {"type": "object", "properties": {}}
            echo_schema = {
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            }
            test_tool = Tool(
                name="test_tool",
                description="A test tool",
                inputSchema=empty_schema,
            )
            echo_tool = Tool(
                name="echo_message",
                description="Echo a message back",
                inputSchema=echo_schema,
            )
            return [test_tool, echo_tool]

        @self.call_tool()
        async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
            # Guard clause: anything that is not the echo tool gets a generic reply.
            if name != "echo_message":
                return [TextContent(type="text", text=f"Called {name}")]
            echoed = args.get("message", "")
            return [TextContent(type="text", text=f"Echo: {echoed}")]
@pytest.fixture()
async def redis_server_and_app(message_dispatch: RedisMessageDispatch):
    """Build a Starlette app whose SSE transport dispatches messages via Redis.

    The ``message_dispatch`` fixture is defined outside this module
    (presumably a conftest backed by fakeredis; confirm there).

    Returns ``(app, message_dispatch, redis_client)``. ``redis_client`` is
    ``message_dispatch._redis``, exposed so tests can inspect the keys the
    dispatcher writes. The app's lifespan closes the dispatcher on shutdown.
    """
    transport = SseServerTransport("/messages/", message_dispatch=message_dispatch)
    mcp_server = RedisTestServer()

    async def handle_sse(request: Request):
        # Starlette has no public accessor for the ASGI send callable, hence
        # the private ``request._send``.
        async with transport.connect_sse(
            request.scope, request.receive, request._send
        ) as streams:
            read_stream, write_stream = streams[0], streams[1]
            await mcp_server.run(
                read_stream, write_stream, mcp_server.create_initialization_options()
            )
        # NOTE(review): an empty Response is returned once the SSE stream ends,
        # presumably so the endpoint still yields a response object; confirm.
        return Response()

    @asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
        """Close the Redis message dispatch when the application shuts down."""
        try:
            yield
        finally:
            await message_dispatch.close()

    routes = [
        Route("/sse", endpoint=handle_sse),
        Mount("/messages/", app=transport.handle_post_message),
    ]
    app = Starlette(routes=routes, lifespan=lifespan)
    return app, message_dispatch, message_dispatch._redis
@pytest.fixture()
async def server_and_redis(redis_server_and_app, server_port: int):
    """Serve the test app with uvicorn in a background task.

    Yields ``(redis_client, message_dispatch)`` once the server port accepts
    TCP connections. Raises RuntimeError if it never does.

    On teardown it signals uvicorn to exit. It also resets sse-starlette's
    class-level ``AppStatus`` so later tests in the same process start clean.
    """
    app, message_dispatch, mock_redis = redis_server_and_app
    uv_server = uvicorn.Server(
        config=uvicorn.Config(
            app=app, host="127.0.0.1", port=server_port, log_level="error"
        )
    )
    max_attempts = 20
    try:
        async with anyio.create_task_group() as tg:
            tg.start_soon(uv_server.serve)

            # Probe the port until uvicorn accepts connections, sleeping 0.1 s
            # after each refused attempt (about two seconds in total).
            for _ in range(max_attempts):
                try:
                    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                        probe.connect(("127.0.0.1", server_port))
                except ConnectionRefusedError:
                    await anyio.sleep(0.1)
                else:
                    break
            else:
                raise RuntimeError(
                    f"Server failed to start after {max_attempts} attempts"
                )

            try:
                yield mock_redis, message_dispatch
            finally:
                uv_server.should_exit = True
    finally:
        # sse-starlette keeps these as class-level state and is not designed to
        # run several times in one Python process, so reset them manually.
        AppStatus.should_exit = False
        AppStatus.should_exit_event = None
@pytest.fixture()
async def client_session(server_and_redis, server_url: str):
    """Yield an initialized MCP ClientSession connected over SSE.

    Depends on ``server_and_redis`` so the server is running first. Checks
    that the initialize handshake reports ``SERVER_NAME``.
    """
    sse_endpoint = server_url + "/sse"
    async with sse_client(sse_endpoint) as streams, ClientSession(*streams) as session:
        init_result = await session.initialize()
        assert init_result.serverInfo.name == SERVER_NAME
        yield session
@pytest.mark.anyio
async def test_redis_integration_key_verification(
    server_and_redis, client_session
) -> None:
    """An active session is recorded in Redis under a key that has a TTL."""
    redis_client, _ = server_and_redis

    all_keys = await redis_client.keys("*")  # type: ignore
    assert len(all_keys) > 0

    # First key marking an active session, or None if there is none.
    session_key = next(
        (key for key in all_keys if key.startswith("mcp:pubsub:session_active:")),
        None,
    )
    assert session_key is not None, f"No session key found. Keys: {all_keys}"

    ttl = await redis_client.ttl(session_key)  # type: ignore
    assert ttl > 0, f"Session key should have TTL, got: {ttl}"
@pytest.mark.anyio
async def test_tool_calls(server_and_redis, client_session) -> None:
    """Round-trip several tool calls through the Redis-dispatched server.

    The server's SSE transport is configured with the Redis message dispatch,
    so a correct echo for each call shows client requests reach the server.
    ``server_and_redis`` is requested only to guarantee the server is running.
    """
    for i in range(3):
        tool_result = await client_session.call_tool(
            "echo_message", {"message": f"Test {i}"}
        )
        content = tool_result.content[0]
        # Narrow the content union explicitly instead of silencing the type checker.
        assert isinstance(content, TextContent), f"Unexpected content: {content!r}"
        assert content.text == f"Echo: Test {i}"
@pytest.mark.anyio
async def test_session_cleanup(server_and_redis, server_url: str) -> None:
    """Each session gets its own active-session key, removed after disconnect.

    Opens three sessions one after another. While a session is live, each
    ``mcp:pubsub:session_active:*`` key must hold ``"1"``. After the session
    closes, Redis must become empty again. The three sessions must use three
    distinct keys.
    """
    mock_redis, _ = server_and_redis
    cleanup_timeout = 2.0  # upper bound on waiting for server-side cleanup
    session_keys_seen = set()
    for i in range(3):
        async with sse_client(server_url + "/sse") as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()
                all_keys = await mock_redis.keys("*")  # type: ignore
                active_keys = [
                    key
                    for key in all_keys
                    if key.startswith("mcp:pubsub:session_active:")
                ]
                assert active_keys, (
                    f"Session {i} registered no active-session key. Keys: {all_keys}"
                )
                for key in active_keys:
                    session_keys_seen.add(key)
                    value = await mock_redis.get(key)  # type: ignore
                    assert value == "1"

        # Cleanup happens on the server after the client disconnects. Poll, with
        # an upper bound, instead of relying on one fixed sleep that can be too
        # short on a loaded machine.
        with anyio.move_on_after(cleanup_timeout):
            while await mock_redis.keys("*"):  # type: ignore
                await anyio.sleep(0.05)
        all_keys = await mock_redis.keys("*")  # type: ignore
        assert (
            len(all_keys) == 0
        ), f"Session keys should be cleaned up, found: {all_keys}"

    # Each session should have used a different key.
    assert len(session_keys_seen) == 3, "Should have seen 3 unique session keys"
@pytest.mark.anyio
async def test_concurrent_tool_call(server_and_redis, server_url: str) -> None:
    """Several clients call a tool concurrently; Redis keys are cleaned up after.

    The function was previously named ``concurrent_tool_call``. Without the
    ``test_`` prefix pytest never collected it, so it never ran.
    """
    mock_redis, _ = server_and_redis
    client_count = 3
    cleanup_timeout = 2.0  # upper bound on waiting for server-side cleanup
    # Content each client received, keyed by client id. Assertions happen after
    # the task group, so a mismatch is not wrapped in an exception group.
    received: dict[int, object] = {}

    async def client_task(client_id: int) -> None:
        async with sse_client(server_url + "/sse") as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()
                result = await session.call_tool(
                    "echo_message",
                    {"message": f"Message from client {client_id}"},
                )
                received[client_id] = result.content[0]

    # An anyio task group (rather than asyncio.gather) matches the rest of this
    # module and cancels the sibling clients if one of them fails.
    async with anyio.create_task_group() as tg:
        for client_id in range(client_count):
            tg.start_soon(client_task, client_id)

    # Verify all clients received their respective messages.
    assert sorted(received) == list(range(client_count))
    for client_id, content in received.items():
        assert isinstance(content, TextContent), f"Unexpected content: {content!r}"
        assert content.text == f"Echo: Message from client {client_id}"

    # After all clients disconnect, keys should be cleaned up. Poll with an
    # upper bound rather than relying on one fixed sleep.
    with anyio.move_on_after(cleanup_timeout):
        while await mock_redis.keys("*"):  # type: ignore
            await anyio.sleep(0.05)
    all_keys = await mock_redis.keys("*")  # type: ignore
    assert len(all_keys) == 0, f"Session keys should be cleaned up, found: {all_keys}"


# Backward-compatible alias for the previous name. It has no ``test_`` prefix,
# so pytest's default discovery does not collect the test a second time.
concurrent_tool_call = test_concurrent_tool_call