mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
Add message queue for SSE messages POST endpoint (#459)
This commit is contained in:
@@ -62,7 +62,7 @@ async def test_client_session_initialize():
|
||||
async with server_to_client_send:
|
||||
await server_to_client_send.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
@@ -153,7 +153,7 @@ async def test_client_session_custom_client_info():
|
||||
async with server_to_client_send:
|
||||
await server_to_client_send.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
@@ -220,7 +220,7 @@ async def test_client_session_default_client_info():
|
||||
async with server_to_client_send:
|
||||
await server_to_client_send.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
|
||||
@@ -23,7 +23,7 @@ async def test_stdio_client():
|
||||
|
||||
async with write_stream:
|
||||
for message in messages:
|
||||
session_message = SessionMessage(message)
|
||||
session_message = SessionMessage(message=message)
|
||||
await write_stream.send(session_message)
|
||||
|
||||
read_messages = []
|
||||
|
||||
@@ -65,7 +65,7 @@ async def test_request_id_match() -> None:
|
||||
jsonrpc="2.0",
|
||||
)
|
||||
|
||||
await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req)))
|
||||
await client_writer.send(SessionMessage(message=JSONRPCMessage(root=init_req)))
|
||||
response = (
|
||||
await server_reader.receive()
|
||||
) # Get init response but don't need to check it
|
||||
@@ -77,7 +77,7 @@ async def test_request_id_match() -> None:
|
||||
jsonrpc="2.0",
|
||||
)
|
||||
await client_writer.send(
|
||||
SessionMessage(JSONRPCMessage(root=initialized_notification))
|
||||
SessionMessage(message=JSONRPCMessage(root=initialized_notification))
|
||||
)
|
||||
|
||||
# Send ping request with custom ID
|
||||
@@ -85,7 +85,9 @@ async def test_request_id_match() -> None:
|
||||
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
|
||||
)
|
||||
|
||||
await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request)))
|
||||
await client_writer.send(
|
||||
SessionMessage(message=JSONRPCMessage(root=ping_request))
|
||||
)
|
||||
|
||||
# Read response
|
||||
response = await server_reader.receive()
|
||||
|
||||
1
tests/server/message_dispatch/__init__.py
Normal file
1
tests/server/message_dispatch/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Message queue tests module
|
||||
28
tests/server/message_dispatch/conftest.py
Normal file
28
tests/server/message_dispatch/conftest.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Shared fixtures for message queue tests."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp.server.message_queue.redis import RedisMessageDispatch
|
||||
|
||||
# Set up fakeredis for testing
|
||||
try:
|
||||
from fakeredis import aioredis as fake_redis
|
||||
except ImportError:
|
||||
pytest.skip(
|
||||
"fakeredis is required for testing Redis functionality", allow_module_level=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def message_dispatch() -> AsyncGenerator[RedisMessageDispatch, None]:
|
||||
"""Create a shared Redis message dispatch with a fake Redis client."""
|
||||
with patch("mcp.server.message_queue.redis.redis", fake_redis.FakeRedis):
|
||||
# Shorter TTL for testing
|
||||
message_dispatch = RedisMessageDispatch(session_ttl=5)
|
||||
try:
|
||||
yield message_dispatch
|
||||
finally:
|
||||
await message_dispatch.close()
|
||||
355
tests/server/message_dispatch/test_redis.py
Normal file
355
tests/server/message_dispatch/test_redis.py
Normal file
@@ -0,0 +1,355 @@
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.message_queue.redis import RedisMessageDispatch
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_session_heartbeat(message_dispatch):
|
||||
"""Test that session heartbeat refreshes TTL."""
|
||||
session_id = uuid4()
|
||||
|
||||
async with message_dispatch.subscribe(session_id, AsyncMock()):
|
||||
session_key = message_dispatch._session_key(session_id)
|
||||
|
||||
# Initial TTL
|
||||
initial_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore
|
||||
assert initial_ttl > 0
|
||||
|
||||
# Wait for heartbeat to run
|
||||
await anyio.sleep(message_dispatch._session_ttl / 2 + 0.5)
|
||||
|
||||
# TTL should be refreshed
|
||||
refreshed_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore
|
||||
assert refreshed_ttl > 0
|
||||
assert refreshed_ttl <= message_dispatch._session_ttl
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_subscribe_unsubscribe(message_dispatch):
|
||||
"""Test subscribing and unsubscribing from a session."""
|
||||
session_id = uuid4()
|
||||
callback = AsyncMock()
|
||||
|
||||
# Subscribe
|
||||
async with message_dispatch.subscribe(session_id, callback):
|
||||
# Check that session is tracked
|
||||
assert session_id in message_dispatch._session_state
|
||||
assert await message_dispatch.session_exists(session_id)
|
||||
|
||||
# After context exit, session should be cleaned up
|
||||
assert session_id not in message_dispatch._session_state
|
||||
assert not await message_dispatch.session_exists(session_id)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_publish_message_valid_json(message_dispatch: RedisMessageDispatch):
|
||||
"""Test publishing a valid JSON-RPC message."""
|
||||
session_id = uuid4()
|
||||
callback = AsyncMock()
|
||||
message = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
|
||||
)
|
||||
|
||||
# Subscribe to messages
|
||||
async with message_dispatch.subscribe(session_id, callback):
|
||||
# Publish message
|
||||
published = await message_dispatch.publish_message(
|
||||
session_id, SessionMessage(message=message)
|
||||
)
|
||||
assert published
|
||||
|
||||
# Give some time for the message to be processed
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Callback should have been called with the message
|
||||
callback.assert_called_once()
|
||||
call_args = callback.call_args[0][0]
|
||||
assert isinstance(call_args, SessionMessage)
|
||||
assert isinstance(call_args.message.root, types.JSONRPCRequest)
|
||||
assert (
|
||||
call_args.message.root.method == "test"
|
||||
) # Access method through root attribute
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_publish_message_invalid_json(message_dispatch):
|
||||
"""Test publishing an invalid JSON string."""
|
||||
session_id = uuid4()
|
||||
callback = AsyncMock()
|
||||
invalid_json = '{"invalid": "json",,}' # Invalid JSON
|
||||
|
||||
# Subscribe to messages
|
||||
async with message_dispatch.subscribe(session_id, callback):
|
||||
# Publish invalid message
|
||||
published = await message_dispatch.publish_message(session_id, invalid_json)
|
||||
assert published
|
||||
|
||||
# Give some time for the message to be processed
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Callback should have been called with a ValidationError
|
||||
callback.assert_called_once()
|
||||
error = callback.call_args[0][0]
|
||||
assert isinstance(error, ValidationError)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_publish_to_nonexistent_session(message_dispatch: RedisMessageDispatch):
|
||||
"""Test publishing to a session that doesn't exist."""
|
||||
session_id = uuid4()
|
||||
message = SessionMessage(
|
||||
message=types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
|
||||
)
|
||||
)
|
||||
|
||||
published = await message_dispatch.publish_message(session_id, message)
|
||||
assert not published
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_extract_session_id(message_dispatch):
|
||||
"""Test extracting session ID from channel name."""
|
||||
session_id = uuid4()
|
||||
channel = message_dispatch._session_channel(session_id)
|
||||
|
||||
# Valid channel
|
||||
extracted_id = message_dispatch._extract_session_id(channel)
|
||||
assert extracted_id == session_id
|
||||
|
||||
# Invalid channel format
|
||||
extracted_id = message_dispatch._extract_session_id("invalid_channel_name")
|
||||
assert extracted_id is None
|
||||
|
||||
# Invalid UUID in channel
|
||||
invalid_channel = f"{message_dispatch._prefix}session:invalid_uuid"
|
||||
extracted_id = message_dispatch._extract_session_id(invalid_channel)
|
||||
assert extracted_id is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multiple_sessions(message_dispatch: RedisMessageDispatch):
|
||||
"""Test handling multiple concurrent sessions."""
|
||||
session1 = uuid4()
|
||||
session2 = uuid4()
|
||||
callback1 = AsyncMock()
|
||||
callback2 = AsyncMock()
|
||||
|
||||
async with message_dispatch.subscribe(session1, callback1):
|
||||
async with message_dispatch.subscribe(session2, callback2):
|
||||
# Both sessions should exist
|
||||
assert await message_dispatch.session_exists(session1)
|
||||
assert await message_dispatch.session_exists(session2)
|
||||
|
||||
# Publish to session1
|
||||
message1 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1}
|
||||
)
|
||||
await message_dispatch.publish_message(
|
||||
session1, SessionMessage(message=message1)
|
||||
)
|
||||
|
||||
# Publish to session2
|
||||
message2 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
|
||||
)
|
||||
await message_dispatch.publish_message(
|
||||
session2, SessionMessage(message=message2)
|
||||
)
|
||||
|
||||
# Give some time for messages to be processed
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Check callbacks
|
||||
callback1.assert_called_once()
|
||||
callback2.assert_called_once()
|
||||
|
||||
call1_args = callback1.call_args[0][0]
|
||||
assert isinstance(call1_args, SessionMessage)
|
||||
assert call1_args.message.root.method == "test1" # type: ignore
|
||||
|
||||
call2_args = callback2.call_args[0][0]
|
||||
assert isinstance(call2_args, SessionMessage)
|
||||
assert call2_args.message.root.method == "test2" # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_task_group_cancellation(message_dispatch):
|
||||
"""Test that task group is properly cancelled when context exits."""
|
||||
session_id = uuid4()
|
||||
callback = AsyncMock()
|
||||
|
||||
async with message_dispatch.subscribe(session_id, callback):
|
||||
# Check that task group is active
|
||||
_, task_group = message_dispatch._session_state[session_id]
|
||||
assert task_group.cancel_scope.cancel_called is False
|
||||
|
||||
# After context exit, task group should be cancelled
|
||||
# And session state should be cleaned up
|
||||
assert session_id not in message_dispatch._session_state
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_session_cancellation_isolation(message_dispatch):
|
||||
"""Test that cancelling one session doesn't affect other sessions."""
|
||||
session1 = uuid4()
|
||||
session2 = uuid4()
|
||||
|
||||
# Create a blocking callback for session1 to ensure it's running when cancelled
|
||||
session1_event = anyio.Event()
|
||||
session1_started = anyio.Event()
|
||||
session1_cancelled = False
|
||||
|
||||
async def blocking_callback1(msg):
|
||||
session1_started.set()
|
||||
try:
|
||||
await session1_event.wait()
|
||||
except anyio.get_cancelled_exc_class():
|
||||
nonlocal session1_cancelled
|
||||
session1_cancelled = True
|
||||
raise
|
||||
|
||||
callback2 = AsyncMock()
|
||||
|
||||
# Start session2 first
|
||||
async with message_dispatch.subscribe(session2, callback2):
|
||||
# Start session1 with a blocking callback
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
async def session1_runner():
|
||||
async with message_dispatch.subscribe(session1, blocking_callback1):
|
||||
# Publish a message to trigger the blocking callback
|
||||
message = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
|
||||
)
|
||||
await message_dispatch.publish_message(session1, message)
|
||||
|
||||
# Wait for the callback to start
|
||||
await session1_started.wait()
|
||||
|
||||
# Keep the context alive while we test cancellation
|
||||
await anyio.sleep_forever()
|
||||
|
||||
tg.start_soon(session1_runner)
|
||||
|
||||
# Wait for session1's callback to start
|
||||
await session1_started.wait()
|
||||
|
||||
# Cancel session1
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
# Give some time for cancellation to propagate
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Verify session1 was cancelled
|
||||
assert session1_cancelled
|
||||
assert session1 not in message_dispatch._session_state
|
||||
|
||||
# Verify session2 is still active and can receive messages
|
||||
assert await message_dispatch.session_exists(session2)
|
||||
message2 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
|
||||
)
|
||||
await message_dispatch.publish_message(session2, message2)
|
||||
|
||||
# Give some time for the message to be processed
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Verify session2 received the message
|
||||
callback2.assert_called_once()
|
||||
call_args = callback2.call_args[0][0]
|
||||
assert call_args.root.method == "test2"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_listener_task_handoff_on_cancellation(message_dispatch):
|
||||
"""
|
||||
Test that the single listening task is properly
|
||||
handed off when a session is cancelled.
|
||||
"""
|
||||
session1 = uuid4()
|
||||
session2 = uuid4()
|
||||
|
||||
session1_messages_received = 0
|
||||
session2_messages_received = 0
|
||||
|
||||
async def callback1(msg):
|
||||
nonlocal session1_messages_received
|
||||
session1_messages_received += 1
|
||||
|
||||
async def callback2(msg):
|
||||
nonlocal session2_messages_received
|
||||
session2_messages_received += 1
|
||||
|
||||
# Create a cancel scope for session1
|
||||
async with anyio.create_task_group() as tg:
|
||||
session1_cancel_scope: anyio.CancelScope | None = None
|
||||
|
||||
async def session1_runner():
|
||||
nonlocal session1_cancel_scope
|
||||
with anyio.CancelScope() as cancel_scope:
|
||||
session1_cancel_scope = cancel_scope
|
||||
async with message_dispatch.subscribe(session1, callback1):
|
||||
# Keep session alive until cancelled
|
||||
await anyio.sleep_forever()
|
||||
|
||||
# Start session1
|
||||
tg.start_soon(session1_runner)
|
||||
|
||||
# Wait for session1 to be established
|
||||
await anyio.sleep(0.1)
|
||||
assert session1 in message_dispatch._session_state
|
||||
|
||||
# Send message to session1 to verify it's working
|
||||
message1 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1}
|
||||
)
|
||||
await message_dispatch.publish_message(session1, message1)
|
||||
await anyio.sleep(0.1)
|
||||
assert session1_messages_received == 1
|
||||
|
||||
# Start session2 while session1 is still active
|
||||
async with message_dispatch.subscribe(session2, callback2):
|
||||
# Both sessions should be active
|
||||
assert session1 in message_dispatch._session_state
|
||||
assert session2 in message_dispatch._session_state
|
||||
|
||||
# Cancel session1
|
||||
assert session1_cancel_scope is not None
|
||||
session1_cancel_scope.cancel()
|
||||
|
||||
# Wait for cancellation to complete
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Session1 should be gone, session2 should remain
|
||||
assert session1 not in message_dispatch._session_state
|
||||
assert session2 in message_dispatch._session_state
|
||||
|
||||
# Send message to session2 to verify the listener was handed off
|
||||
message2 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
|
||||
)
|
||||
await message_dispatch.publish_message(session2, message2)
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Session2 should have received the message
|
||||
assert session2_messages_received == 1
|
||||
|
||||
# Session1 shouldn't receive any more messages
|
||||
assert session1_messages_received == 1
|
||||
|
||||
# Send another message to verify the listener is still working
|
||||
message3 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test3", "params": {}, "id": 3}
|
||||
)
|
||||
await message_dispatch.publish_message(session2, message3)
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
assert session2_messages_received == 2
|
||||
260
tests/server/message_dispatch/test_redis_integration.py
Normal file
260
tests/server/message_dispatch/test_redis_integration.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Integration tests for Redis message dispatch functionality.
|
||||
|
||||
These tests validate Redis message dispatch by making actual HTTP calls and testing
|
||||
that messages flow correctly through the Redis backend.
|
||||
|
||||
This version runs the server in a task instead of a separate process to allow
|
||||
access to the fakeredis instance for verification of Redis keys.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
import uvicorn
|
||||
from sse_starlette.sse import AppStatus
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.server import Server
|
||||
from mcp.server.message_queue.redis import RedisMessageDispatch
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
SERVER_NAME = "test_server_for_redis_integration_v3"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_port() -> int:
|
||||
with socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_url(server_port: int) -> str:
|
||||
return f"http://127.0.0.1:{server_port}"
|
||||
|
||||
|
||||
class RedisTestServer(Server):
|
||||
"""Test server with basic tool functionality."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(SERVER_NAME)
|
||||
|
||||
@self.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
return [
|
||||
Tool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
Tool(
|
||||
name="echo_message",
|
||||
description="Echo a message back",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@self.call_tool()
|
||||
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
||||
if name == "echo_message":
|
||||
message = args.get("message", "")
|
||||
return [TextContent(type="text", text=f"Echo: {message}")]
|
||||
return [TextContent(type="text", text=f"Called {name}")]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def redis_server_and_app(message_dispatch: RedisMessageDispatch):
|
||||
"""Create a mock Redis instance and Starlette app for testing."""
|
||||
|
||||
# Create SSE transport with Redis message dispatch
|
||||
sse = SseServerTransport("/messages/", message_dispatch=message_dispatch)
|
||||
server = RedisTestServer()
|
||||
|
||||
async def handle_sse(request: Request):
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
await server.run(
|
||||
streams[0], streams[1], server.create_initialization_options()
|
||||
)
|
||||
return Response()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
|
||||
"""Manage the lifecycle of the application."""
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await message_dispatch.close()
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
return app, message_dispatch, message_dispatch._redis
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def server_and_redis(redis_server_and_app, server_port: int):
|
||||
"""Run the server in a task and return the Redis instance for inspection."""
|
||||
app, message_dispatch, mock_redis = redis_server_and_app
|
||||
|
||||
# Create a server config
|
||||
config = uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
server = uvicorn.Server(config=config)
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Start server in background
|
||||
tg.start_soon(server.serve)
|
||||
|
||||
# Wait for server to be ready
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect(("127.0.0.1", server_port))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
await anyio.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
try:
|
||||
yield mock_redis, message_dispatch
|
||||
finally:
|
||||
server.should_exit = True
|
||||
finally:
|
||||
# These class variables are set top-level in starlette-sse
|
||||
# It isn't designed to be run multiple times in a single
|
||||
# Python process so we need to manually reset them.
|
||||
AppStatus.should_exit = False
|
||||
AppStatus.should_exit_event = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def client_session(server_and_redis, server_url: str):
|
||||
"""Create a client session for testing."""
|
||||
async with sse_client(server_url + "/sse") as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
result = await session.initialize()
|
||||
assert result.serverInfo.name == SERVER_NAME
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_redis_integration_key_verification(
|
||||
server_and_redis, client_session
|
||||
) -> None:
|
||||
"""Test that Redis keys are created correctly for sessions."""
|
||||
mock_redis, _ = server_and_redis
|
||||
|
||||
all_keys = await mock_redis.keys("*") # type: ignore
|
||||
|
||||
assert len(all_keys) > 0
|
||||
|
||||
session_key = None
|
||||
for key in all_keys:
|
||||
if key.startswith("mcp:pubsub:session_active:"):
|
||||
session_key = key
|
||||
break
|
||||
|
||||
assert session_key is not None, f"No session key found. Keys: {all_keys}"
|
||||
|
||||
ttl = await mock_redis.ttl(session_key) # type: ignore
|
||||
assert ttl > 0, f"Session key should have TTL, got: {ttl}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_calls(server_and_redis, client_session) -> None:
|
||||
"""Test that messages are properly published through Redis."""
|
||||
mock_redis, _ = server_and_redis
|
||||
|
||||
for i in range(3):
|
||||
tool_result = await client_session.call_tool(
|
||||
"echo_message", {"message": f"Test {i}"}
|
||||
)
|
||||
assert tool_result.content[0].text == f"Echo: Test {i}" # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_session_cleanup(server_and_redis, server_url: str) -> None:
|
||||
"""Test Redis key cleanup when sessions end."""
|
||||
mock_redis, _ = server_and_redis
|
||||
session_keys_seen = set()
|
||||
|
||||
for i in range(3):
|
||||
async with sse_client(server_url + "/sse") as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
|
||||
all_keys = await mock_redis.keys("*") # type: ignore
|
||||
for key in all_keys:
|
||||
if key.startswith("mcp:pubsub:session_active:"):
|
||||
session_keys_seen.add(key)
|
||||
value = await mock_redis.get(key) # type: ignore
|
||||
assert value == "1"
|
||||
|
||||
await anyio.sleep(0.1) # Give time for cleanup
|
||||
all_keys = await mock_redis.keys("*") # type: ignore
|
||||
assert (
|
||||
len(all_keys) == 0
|
||||
), f"Session keys should be cleaned up, found: {all_keys}"
|
||||
|
||||
# Verify we saw different session keys for each session
|
||||
assert len(session_keys_seen) == 3, "Should have seen 3 unique session keys"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def concurrent_tool_call(server_and_redis, server_url: str) -> None:
|
||||
"""Test multiple clients and verify Redis key management."""
|
||||
mock_redis, _ = server_and_redis
|
||||
|
||||
async def client_task(client_id: int) -> str:
|
||||
async with sse_client(server_url + "/sse") as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
|
||||
result = await session.call_tool(
|
||||
"echo_message",
|
||||
{"message": f"Message from client {client_id}"},
|
||||
)
|
||||
return result.content[0].text # type: ignore
|
||||
|
||||
# Run multiple clients concurrently
|
||||
client_tasks = [client_task(i) for i in range(3)]
|
||||
results = await asyncio.gather(*client_tasks)
|
||||
|
||||
# Verify all clients received their respective messages
|
||||
assert len(results) == 3
|
||||
for i, result in enumerate(results):
|
||||
assert result == f"Echo: Message from client {i}"
|
||||
|
||||
# After all clients disconnect, keys should be cleaned up
|
||||
await anyio.sleep(0.1) # Give time for cleanup
|
||||
all_keys = await mock_redis.keys("*") # type: ignore
|
||||
assert len(all_keys) == 0, f"Session keys should be cleaned up, found: {all_keys}"
|
||||
@@ -84,7 +84,7 @@ async def test_lowlevel_server_lifespan():
|
||||
)
|
||||
await send_stream1.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
root=JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=1,
|
||||
@@ -100,7 +100,7 @@ async def test_lowlevel_server_lifespan():
|
||||
# Send initialized notification
|
||||
await send_stream1.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
root=JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
method="notifications/initialized",
|
||||
@@ -112,7 +112,7 @@ async def test_lowlevel_server_lifespan():
|
||||
# Call the tool to verify lifespan context
|
||||
await send_stream1.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
root=JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=2,
|
||||
@@ -188,7 +188,7 @@ async def test_fastmcp_server_lifespan():
|
||||
)
|
||||
await send_stream1.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
root=JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=1,
|
||||
@@ -204,7 +204,7 @@ async def test_fastmcp_server_lifespan():
|
||||
# Send initialized notification
|
||||
await send_stream1.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
root=JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
method="notifications/initialized",
|
||||
@@ -216,7 +216,7 @@ async def test_fastmcp_server_lifespan():
|
||||
# Call the tool to verify lifespan context
|
||||
await send_stream1.send(
|
||||
SessionMessage(
|
||||
JSONRPCMessage(
|
||||
message=JSONRPCMessage(
|
||||
root=JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=2,
|
||||
|
||||
@@ -51,7 +51,7 @@ async def test_stdio_server():
|
||||
|
||||
async with write_stream:
|
||||
for response in responses:
|
||||
session_message = SessionMessage(response)
|
||||
session_message = SessionMessage(message=response)
|
||||
await write_stream.send(session_message)
|
||||
|
||||
stdout.seek(0)
|
||||
|
||||
Reference in New Issue
Block a user