mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Add message queue for SSE messages POST endpoint (#459)
This commit is contained in:
260
tests/server/message_dispatch/test_redis_integration.py
Normal file
260
tests/server/message_dispatch/test_redis_integration.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Integration tests for Redis message dispatch functionality.
|
||||
|
||||
These tests validate Redis message dispatch by making actual HTTP calls and testing
|
||||
that messages flow correctly through the Redis backend.
|
||||
|
||||
This version runs the server in a task instead of a separate process to allow
|
||||
access to the fakeredis instance for verification of Redis keys.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
import uvicorn
|
||||
from sse_starlette.sse import AppStatus
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.server import Server
|
||||
from mcp.server.message_queue.redis import RedisMessageDispatch
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
SERVER_NAME = "test_server_for_redis_integration_v3"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_port() -> int:
|
||||
with socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_url(server_port: int) -> str:
|
||||
return f"http://127.0.0.1:{server_port}"
|
||||
|
||||
|
||||
class RedisTestServer(Server):
|
||||
"""Test server with basic tool functionality."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(SERVER_NAME)
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
return [
|
||||
Tool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
Tool(
|
||||
name="echo_message",
|
||||
description="Echo a message back",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@self.call_tool()
|
||||
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
||||
if name == "echo_message":
|
||||
message = args.get("message", "")
|
||||
return [TextContent(type="text", text=f"Echo: {message}")]
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def redis_server_and_app(message_dispatch: RedisMessageDispatch):
|
||||
"""Create a mock Redis instance and Starlette app for testing."""
|
||||
|
||||
# Create SSE transport with Redis message dispatch
|
||||
sse = SseServerTransport("/messages/", message_dispatch=message_dispatch)
|
||||
server = RedisTestServer()
|
||||
|
||||
async def handle_sse(request: Request):
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
await server.run(
|
||||
streams[0], streams[1], server.create_initialization_options()
|
||||
)
|
||||
return Response()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
|
||||
"""Manage the lifecycle of the application."""
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await message_dispatch.close()
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
return app, message_dispatch, message_dispatch._redis
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def server_and_redis(redis_server_and_app, server_port: int):
|
||||
"""Run the server in a task and return the Redis instance for inspection."""
|
||||
app, message_dispatch, mock_redis = redis_server_and_app
|
||||
|
||||
# Create a server config
|
||||
config = uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
server = uvicorn.Server(config=config)
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Start server in background
|
||||
tg.start_soon(server.serve)
|
||||
|
||||
# Wait for server to be ready
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
await anyio.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
try:
|
||||
yield mock_redis, message_dispatch
|
||||
finally:
|
||||
server.should_exit = True
|
||||
finally:
|
||||
# These class variables are set top-level in starlette-sse
|
||||
# It isn't designed to be run multiple times in a single
|
||||
# Python process so we need to manually reset them.
|
||||
AppStatus.should_exit = False
|
||||
AppStatus.should_exit_event = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def client_session(server_and_redis, server_url: str):
|
||||
"""Create a client session for testing."""
|
||||
async with sse_client(server_url + "/sse") as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
result = await session.initialize()
|
||||
assert result.serverInfo.name == SERVER_NAME
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_redis_integration_key_verification(
|
||||
server_and_redis, client_session
|
||||
) -> None:
|
||||
"""Test that Redis keys are created correctly for sessions."""
|
||||
mock_redis, _ = server_and_redis
|
||||
|
||||
all_keys = await mock_redis.keys("*") # type: ignore
|
||||
|
||||
assert len(all_keys) > 0
|
||||
|
||||
session_key = None
|
||||
for key in all_keys:
|
||||
if key.startswith("mcp:pubsub:session_active:"):
|
||||
session_key = key
|
||||
break
|
||||
|
||||
assert session_key is not None, f"No session key found. Keys: {all_keys}"
|
||||
|
||||
ttl = await mock_redis.ttl(session_key) # type: ignore
|
||||
assert ttl > 0, f"Session key should have TTL, got: {ttl}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_calls(server_and_redis, client_session) -> None:
|
||||
"""Test that messages are properly published through Redis."""
|
||||
mock_redis, _ = server_and_redis
|
||||
|
||||
for i in range(3):
|
||||
tool_result = await client_session.call_tool(
|
||||
"echo_message", {"message": f"Test {i}"}
|
||||
)
|
||||
assert tool_result.content[0].text == f"Echo: Test {i}" # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_session_cleanup(server_and_redis, server_url: str) -> None:
|
||||
"""Test Redis key cleanup when sessions end."""
|
||||
mock_redis, _ = server_and_redis
|
||||
session_keys_seen = set()
|
||||
|
||||
for i in range(3):
|
||||
async with sse_client(server_url + "/sse") as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
|
||||
all_keys = await mock_redis.keys("*") # type: ignore
|
||||
for key in all_keys:
|
||||
if key.startswith("mcp:pubsub:session_active:"):
|
||||
session_keys_seen.add(key)
|
||||
value = await mock_redis.get(key) # type: ignore
|
||||
assert value == "1"
|
||||
|
||||
await anyio.sleep(0.1) # Give time for cleanup
|
||||
all_keys = await mock_redis.keys("*") # type: ignore
|
||||
assert (
|
||||
len(all_keys) == 0
|
||||
), f"Session keys should be cleaned up, found: {all_keys}"
|
||||
|
||||
# Verify we saw different session keys for each session
|
||||
assert len(session_keys_seen) == 3, "Should have seen 3 unique session keys"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def concurrent_tool_call(server_and_redis, server_url: str) -> None:
|
||||
"""Test multiple clients and verify Redis key management."""
|
||||
mock_redis, _ = server_and_redis
|
||||
|
||||
async def client_task(client_id: int) -> str:
|
||||
async with sse_client(server_url + "/sse") as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
|
||||
result = await session.call_tool(
|
||||
"echo_message",
|
||||
{"message": f"Message from client {client_id}"},
|
||||
)
|
||||
return result.content[0].text # type: ignore
|
||||
|
||||
# Run multiple clients concurrently
|
||||
client_tasks = [client_task(i) for i in range(3)]
|
||||
results = await asyncio.gather(*client_tasks)
|
||||
|
||||
# Verify all clients received their respective messages
|
||||
assert len(results) == 3
|
||||
for i, result in enumerate(results):
|
||||
assert result == f"Echo: Message from client {i}"
|
||||
|
||||
# After all clients disconnect, keys should be cleaned up
|
||||
await anyio.sleep(0.1) # Give time for cleanup
|
||||
all_keys = await mock_redis.keys("*") # type: ignore
|
||||
assert len(all_keys) == 0, f"Session keys should be cleaned up, found: {all_keys}"
|
||||
Reference in New Issue
Block a user