mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
261 lines
8.8 KiB
Python
261 lines
8.8 KiB
Python
"""
Integration tests for Redis message dispatch functionality.

These tests validate Redis message dispatch by making actual HTTP calls and
checking that messages flow correctly through the Redis backend.

This version runs the server in a task instead of a separate process, so the
tests can access the fakeredis instance directly and verify the Redis keys.
"""
|
|
|
|
import asyncio
|
|
import socket
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
|
|
import anyio
|
|
import pytest
|
|
import uvicorn
|
|
from sse_starlette.sse import AppStatus
|
|
from starlette.applications import Starlette
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.routing import Mount, Route
|
|
|
|
from mcp.client.session import ClientSession
|
|
from mcp.client.sse import sse_client
|
|
from mcp.server import Server
|
|
from mcp.server.message_queue.redis import RedisMessageDispatch
|
|
from mcp.server.sse import SseServerTransport
|
|
from mcp.types import TextContent, Tool
|
|
|
|
# Name the test server reports during MCP initialization; the client fixture
# asserts the handshake result against this value.
SERVER_NAME: str = "test_server_for_redis_integration_v3"
|
|
|
|
|
|
@pytest.fixture
def server_port() -> int:
    """Reserve an ephemeral TCP port on the loopback interface.

    The OS assigns a free port to a throwaway socket. That socket is closed
    before the number is returned, so the port is only likely (not
    guaranteed) to still be free when the server binds to it later.
    """
    with socket.socket() as probe:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
    return port
|
|
|
|
|
|
@pytest.fixture
def server_url(server_port: int) -> str:
    """Base HTTP URL of the test server, built from the reserved loopback port."""
    url = f"http://127.0.0.1:{server_port}"
    return url
|
|
|
|
|
|
class RedisTestServer(Server):
    """Minimal MCP server used as the target of the Redis integration tests.

    Registers two tools at construction time:

    * ``test_tool``: takes no arguments and replies ``"Called test_tool"``.
    * ``echo_message``: takes a string ``message`` and replies
      ``"Echo: <message>"``.

    Any other tool name is answered with ``"Called <name>"``.
    """

    def __init__(self):
        super().__init__(SERVER_NAME)

        @self.list_tools()
        async def advertise_tools() -> list[Tool]:
            # A fresh list (and fresh schema dicts) is built on every request.
            empty_schema = {"type": "object", "properties": {}}
            echo_schema = {
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            }
            plain_tool = Tool(
                name="test_tool",
                description="A test tool",
                inputSchema=empty_schema,
            )
            echo_tool = Tool(
                name="echo_message",
                description="Echo a message back",
                inputSchema=echo_schema,
            )
            return [plain_tool, echo_tool]

        @self.call_tool()
        async def run_tool(name: str, args: dict) -> list[TextContent]:
            # Guard clause: everything except the echo tool gets a generic reply.
            if name != "echo_message":
                return [TextContent(type="text", text=f"Called {name}")]
            message = args.get("message", "")
            return [TextContent(type="text", text=f"Echo: {message}")]
|
|
|
|
|
|
@pytest.fixture()
async def redis_server_and_app(message_dispatch: RedisMessageDispatch):
    """Build the Starlette app under test around the injected Redis dispatcher.

    No Redis instance is created here; ``message_dispatch`` comes from another
    fixture. The app exposes:

    * ``/sse``: opens an SSE connection and runs a ``RedisTestServer`` over it.
    * ``/messages/``: client POST endpoint, handled by an
      ``SseServerTransport`` configured with ``message_dispatch``.

    The dispatcher is closed when the app's lifespan ends.

    Returns:
        ``(app, message_dispatch, redis_client)``. ``redis_client`` is the
        dispatcher's private ``_redis`` attribute, exposed so tests can
        inspect keys directly.
    """
    transport = SseServerTransport("/messages/", message_dispatch=message_dispatch)
    mcp_server = RedisTestServer()

    async def handle_sse(request: Request):
        # connect_sse needs the raw ASGI send callable, which Starlette only
        # exposes via the private ``request._send`` attribute.
        async with transport.connect_sse(
            request.scope, request.receive, request._send
        ) as streams:
            read_stream, write_stream = streams[0], streams[1]
            init_options = mcp_server.create_initialization_options()
            await mcp_server.run(read_stream, write_stream, init_options)
        # Starlette's request/response wrapper expects the endpoint to return
        # a Response object, so hand back an empty one once the session ends.
        return Response()

    @asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
        """Close the Redis dispatcher when the application shuts down."""
        try:
            yield
        finally:
            await message_dispatch.close()

    routes = [
        Route("/sse", endpoint=handle_sse),
        Mount("/messages/", app=transport.handle_post_message),
    ]
    app = Starlette(routes=routes, lifespan=lifespan)

    return app, message_dispatch, message_dispatch._redis
|
|
|
|
|
|
async def _wait_until_listening(
    host: str, port: int, *, max_attempts: int = 20, delay: float = 0.1
) -> None:
    """Return once a TCP connection to ``(host, port)`` succeeds.

    Uses anyio's non-blocking ``connect_tcp`` so the probe never stalls the
    event loop that is also running the server under test. Any ``OSError``
    (refused, reset, ...) is treated as "not ready yet" and retried after
    ``delay`` seconds.

    Raises:
        RuntimeError: if none of the ``max_attempts`` attempts succeeds.
    """
    for _ in range(max_attempts):
        try:
            probe = await anyio.connect_tcp(host, port)
        except OSError:
            await anyio.sleep(delay)
        else:
            await probe.aclose()
            return
    raise RuntimeError(f"Server failed to start after {max_attempts} attempts")


@pytest.fixture()
async def server_and_redis(redis_server_and_app, server_port: int):
    """Serve the test app with uvicorn in a background task for one test.

    The server runs in the same event loop as the test (not a subprocess), so
    the test can inspect the dispatcher's Redis client directly.

    Yields:
        ``(redis_client, message_dispatch)``: the dispatcher's Redis client
        (for key inspection) and the dispatcher itself.

    Raises:
        RuntimeError: if the server does not accept TCP connections within
            20 attempts spaced 0.1 s apart.
    """
    app, message_dispatch, mock_redis = redis_server_and_app

    config = uvicorn.Config(
        app=app, host="127.0.0.1", port=server_port, log_level="error"
    )
    server = uvicorn.Server(config=config)
    try:
        async with anyio.create_task_group() as tg:
            # Start the server in the background, then block until it accepts
            # connections so the test never races the server's startup.
            tg.start_soon(server.serve)
            await _wait_until_listening("127.0.0.1", server_port)

            try:
                yield mock_redis, message_dispatch
            finally:
                # Ask uvicorn to shut down; the task group then waits for
                # serve() to return before exiting.
                server.should_exit = True
    finally:
        # These class variables are set top-level in starlette-sse
        # It isn't designed to be run multiple times in a single
        # Python process so we need to manually reset them.
        AppStatus.should_exit = False
        AppStatus.should_exit_event = None
|
|
|
|
|
|
@pytest.fixture()
async def client_session(server_and_redis, server_url: str):
    """Yield an initialized MCP client session connected over SSE.

    Depends on ``server_and_redis`` so the server is up first. The handshake
    result is checked against ``SERVER_NAME`` before the session is handed to
    the test.
    """
    sse_url = server_url + "/sse"
    async with sse_client(sse_url) as streams, ClientSession(*streams) as session:
        init_result = await session.initialize()
        assert init_result.serverInfo.name == SERVER_NAME
        yield session
|
|
|
|
|
|
@pytest.mark.anyio
async def test_redis_integration_key_verification(
    server_and_redis, client_session
) -> None:
    """An open session has an active-session key in Redis that carries a TTL."""
    redis_client, _ = server_and_redis

    all_keys = await redis_client.keys("*")  # type: ignore
    assert len(all_keys) > 0

    # First key carrying the active-session prefix, or None if there is none.
    session_key = next(
        (k for k in all_keys if k.startswith("mcp:pubsub:session_active:")),
        None,
    )
    assert session_key is not None, f"No session key found. Keys: {all_keys}"

    ttl = await redis_client.ttl(session_key)  # type: ignore
    assert ttl > 0, f"Session key should have TTL, got: {ttl}"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_tool_calls(server_and_redis, client_session) -> None:
    """Round-trip several ``echo_message`` calls through the test server.

    The server's SSE transport is configured with the Redis message
    dispatcher, so a correct echo for every call indicates client messages
    are routed end to end. ``server_and_redis`` is requested only to
    guarantee the server is running; its Redis client is not inspected here.
    """
    for i in range(3):
        tool_result = await client_session.call_tool(
            "echo_message", {"message": f"Test {i}"}
        )
        content = tool_result.content[0]
        # Narrow the union type explicitly instead of silencing the checker.
        assert isinstance(content, TextContent)
        assert content.text == f"Echo: Test {i}"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_session_cleanup(server_and_redis, server_url: str) -> None:
    """Each session gets its own active-session key, removed after disconnect.

    Three sessions are opened one after another. While a session is open,
    every ``mcp:pubsub:session_active:*`` key must hold ``"1"``. After it
    closes, and after a short grace period, Redis must contain no keys.
    """
    redis_client, _ = server_and_redis
    prefix = "mcp:pubsub:session_active:"
    session_keys_seen: set[str] = set()

    for _round in range(3):
        async with sse_client(server_url + "/sse") as streams, ClientSession(
            *streams
        ) as session:
            await session.initialize()

            keys_now = await redis_client.keys("*")  # type: ignore
            active_keys = [k for k in keys_now if k.startswith(prefix)]
            for key in active_keys:
                session_keys_seen.add(key)
                stored = await redis_client.get(key)  # type: ignore
                assert stored == "1"

        await anyio.sleep(0.1)  # Give time for cleanup
        all_keys = await redis_client.keys("*")  # type: ignore
        assert len(all_keys) == 0, f"Session keys should be cleaned up, found: {all_keys}"

    # Verify we saw different session keys for each session
    assert len(session_keys_seen) == 3, "Should have seen 3 unique session keys"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_concurrent_tool_call(server_and_redis, server_url: str) -> None:
    """Several clients call a tool concurrently; each gets its own reply.

    After all clients disconnect (plus a short grace period), Redis must
    contain no keys.

    This test used to be named ``concurrent_tool_call``. Without the
    ``test_`` prefix, pytest never collected it, so it never ran.
    """
    mock_redis, _ = server_and_redis

    async def client_task(client_id: int) -> str:
        async with sse_client(server_url + "/sse") as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()

                result = await session.call_tool(
                    "echo_message",
                    {"message": f"Message from client {client_id}"},
                )
                content = result.content[0]
                # Narrow the union type explicitly instead of silencing the checker.
                assert isinstance(content, TextContent)
                return content.text

    # Run multiple clients concurrently; gather preserves submission order,
    # so results[i] belongs to client i.
    client_tasks = [client_task(i) for i in range(3)]
    results = await asyncio.gather(*client_tasks)

    # Verify all clients received their respective messages
    assert len(results) == 3
    for i, result in enumerate(results):
        assert result == f"Echo: Message from client {i}"

    # After all clients disconnect, keys should be cleaned up
    await anyio.sleep(0.1)  # Give time for cleanup
    all_keys = await mock_redis.keys("*")  # type: ignore
    assert len(all_keys) == 0, f"Session keys should be cleaned up, found: {all_keys}"


# Backward-compatible alias for the old name. pytest does not collect it
# (no ``test_`` prefix), so the test body still runs exactly once.
concurrent_tool_call = test_concurrent_tool_call
|