Revert "Add message queue for SSE messages POST endpoint (#459)" (#649)

This commit is contained in:
ihrpr
2025-05-07 16:35:20 +01:00
committed by GitHub
parent c8a14c9dba
commit 9d99aee014
26 changed files with 51 additions and 1247 deletions

View File

@@ -62,7 +62,7 @@ async def test_client_session_initialize():
async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
@@ -153,7 +153,7 @@ async def test_client_session_custom_client_info():
async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
@@ -220,7 +220,7 @@ async def test_client_session_default_client_info():
async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,

View File

@@ -23,7 +23,7 @@ async def test_stdio_client():
async with write_stream:
for message in messages:
session_message = SessionMessage(message=message)
session_message = SessionMessage(message)
await write_stream.send(session_message)
read_messages = []

View File

@@ -65,7 +65,7 @@ async def test_request_id_match() -> None:
jsonrpc="2.0",
)
await client_writer.send(SessionMessage(message=JSONRPCMessage(root=init_req)))
await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req)))
response = (
await server_reader.receive()
) # Get init response but don't need to check it
@@ -77,7 +77,7 @@ async def test_request_id_match() -> None:
jsonrpc="2.0",
)
await client_writer.send(
SessionMessage(message=JSONRPCMessage(root=initialized_notification))
SessionMessage(JSONRPCMessage(root=initialized_notification))
)
# Send ping request with custom ID
@@ -85,9 +85,7 @@ async def test_request_id_match() -> None:
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
)
await client_writer.send(
SessionMessage(message=JSONRPCMessage(root=ping_request))
)
await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request)))
# Read response
response = await server_reader.receive()

View File

@@ -1 +0,0 @@
# Message queue tests module

View File

@@ -1,28 +0,0 @@
"""Shared fixtures for message queue tests."""
from collections.abc import AsyncGenerator
from unittest.mock import patch
import pytest
from mcp.server.message_queue.redis import RedisMessageDispatch
# Set up fakeredis for testing
try:
from fakeredis import aioredis as fake_redis
except ImportError:
pytest.skip(
"fakeredis is required for testing Redis functionality", allow_module_level=True
)
@pytest.fixture
async def message_dispatch() -> AsyncGenerator[RedisMessageDispatch, None]:
"""Create a shared Redis message dispatch with a fake Redis client."""
with patch("mcp.server.message_queue.redis.redis", fake_redis.FakeRedis):
# Shorter TTL for testing
message_dispatch = RedisMessageDispatch(session_ttl=5)
try:
yield message_dispatch
finally:
await message_dispatch.close()

View File

@@ -1,355 +0,0 @@
from unittest.mock import AsyncMock
from uuid import uuid4
import anyio
import pytest
from pydantic import ValidationError
import mcp.types as types
from mcp.server.message_queue.redis import RedisMessageDispatch
from mcp.shared.message import SessionMessage
@pytest.mark.anyio
async def test_session_heartbeat(message_dispatch):
"""Test that session heartbeat refreshes TTL."""
session_id = uuid4()
async with message_dispatch.subscribe(session_id, AsyncMock()):
session_key = message_dispatch._session_key(session_id)
# Initial TTL
initial_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore
assert initial_ttl > 0
# Wait for heartbeat to run
await anyio.sleep(message_dispatch._session_ttl / 2 + 0.5)
# TTL should be refreshed
refreshed_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore
assert refreshed_ttl > 0
assert refreshed_ttl <= message_dispatch._session_ttl
@pytest.mark.anyio
async def test_subscribe_unsubscribe(message_dispatch):
"""Test subscribing and unsubscribing from a session."""
session_id = uuid4()
callback = AsyncMock()
# Subscribe
async with message_dispatch.subscribe(session_id, callback):
# Check that session is tracked
assert session_id in message_dispatch._session_state
assert await message_dispatch.session_exists(session_id)
# After context exit, session should be cleaned up
assert session_id not in message_dispatch._session_state
assert not await message_dispatch.session_exists(session_id)
@pytest.mark.anyio
async def test_publish_message_valid_json(message_dispatch: RedisMessageDispatch):
"""Test publishing a valid JSON-RPC message."""
session_id = uuid4()
callback = AsyncMock()
message = types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
)
# Subscribe to messages
async with message_dispatch.subscribe(session_id, callback):
# Publish message
published = await message_dispatch.publish_message(
session_id, SessionMessage(message=message)
)
assert published
# Give some time for the message to be processed
await anyio.sleep(0.1)
# Callback should have been called with the message
callback.assert_called_once()
call_args = callback.call_args[0][0]
assert isinstance(call_args, SessionMessage)
assert isinstance(call_args.message.root, types.JSONRPCRequest)
assert (
call_args.message.root.method == "test"
) # Access method through root attribute
@pytest.mark.anyio
async def test_publish_message_invalid_json(message_dispatch):
"""Test publishing an invalid JSON string."""
session_id = uuid4()
callback = AsyncMock()
invalid_json = '{"invalid": "json",,}' # Invalid JSON
# Subscribe to messages
async with message_dispatch.subscribe(session_id, callback):
# Publish invalid message
published = await message_dispatch.publish_message(session_id, invalid_json)
assert published
# Give some time for the message to be processed
await anyio.sleep(0.1)
# Callback should have been called with a ValidationError
callback.assert_called_once()
error = callback.call_args[0][0]
assert isinstance(error, ValidationError)
@pytest.mark.anyio
async def test_publish_to_nonexistent_session(message_dispatch: RedisMessageDispatch):
"""Test publishing to a session that doesn't exist."""
session_id = uuid4()
message = SessionMessage(
message=types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
)
)
published = await message_dispatch.publish_message(session_id, message)
assert not published
@pytest.mark.anyio
async def test_extract_session_id(message_dispatch):
"""Test extracting session ID from channel name."""
session_id = uuid4()
channel = message_dispatch._session_channel(session_id)
# Valid channel
extracted_id = message_dispatch._extract_session_id(channel)
assert extracted_id == session_id
# Invalid channel format
extracted_id = message_dispatch._extract_session_id("invalid_channel_name")
assert extracted_id is None
# Invalid UUID in channel
invalid_channel = f"{message_dispatch._prefix}session:invalid_uuid"
extracted_id = message_dispatch._extract_session_id(invalid_channel)
assert extracted_id is None
@pytest.mark.anyio
async def test_multiple_sessions(message_dispatch: RedisMessageDispatch):
"""Test handling multiple concurrent sessions."""
session1 = uuid4()
session2 = uuid4()
callback1 = AsyncMock()
callback2 = AsyncMock()
async with message_dispatch.subscribe(session1, callback1):
async with message_dispatch.subscribe(session2, callback2):
# Both sessions should exist
assert await message_dispatch.session_exists(session1)
assert await message_dispatch.session_exists(session2)
# Publish to session1
message1 = types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1}
)
await message_dispatch.publish_message(
session1, SessionMessage(message=message1)
)
# Publish to session2
message2 = types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
)
await message_dispatch.publish_message(
session2, SessionMessage(message=message2)
)
# Give some time for messages to be processed
await anyio.sleep(0.1)
# Check callbacks
callback1.assert_called_once()
callback2.assert_called_once()
call1_args = callback1.call_args[0][0]
assert isinstance(call1_args, SessionMessage)
assert call1_args.message.root.method == "test1" # type: ignore
call2_args = callback2.call_args[0][0]
assert isinstance(call2_args, SessionMessage)
assert call2_args.message.root.method == "test2" # type: ignore
@pytest.mark.anyio
async def test_task_group_cancellation(message_dispatch):
"""Test that task group is properly cancelled when context exits."""
session_id = uuid4()
callback = AsyncMock()
async with message_dispatch.subscribe(session_id, callback):
# Check that task group is active
_, task_group = message_dispatch._session_state[session_id]
assert task_group.cancel_scope.cancel_called is False
# After context exit, task group should be cancelled
# And session state should be cleaned up
assert session_id not in message_dispatch._session_state
@pytest.mark.anyio
async def test_session_cancellation_isolation(message_dispatch):
"""Test that cancelling one session doesn't affect other sessions."""
session1 = uuid4()
session2 = uuid4()
# Create a blocking callback for session1 to ensure it's running when cancelled
session1_event = anyio.Event()
session1_started = anyio.Event()
session1_cancelled = False
async def blocking_callback1(msg):
session1_started.set()
try:
await session1_event.wait()
except anyio.get_cancelled_exc_class():
nonlocal session1_cancelled
session1_cancelled = True
raise
callback2 = AsyncMock()
# Start session2 first
async with message_dispatch.subscribe(session2, callback2):
# Start session1 with a blocking callback
async with anyio.create_task_group() as tg:
async def session1_runner():
async with message_dispatch.subscribe(session1, blocking_callback1):
# Publish a message to trigger the blocking callback
message = types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
)
await message_dispatch.publish_message(session1, message)
# Wait for the callback to start
await session1_started.wait()
# Keep the context alive while we test cancellation
await anyio.sleep_forever()
tg.start_soon(session1_runner)
# Wait for session1's callback to start
await session1_started.wait()
# Cancel session1
tg.cancel_scope.cancel()
# Give some time for cancellation to propagate
await anyio.sleep(0.1)
# Verify session1 was cancelled
assert session1_cancelled
assert session1 not in message_dispatch._session_state
# Verify session2 is still active and can receive messages
assert await message_dispatch.session_exists(session2)
message2 = types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
)
await message_dispatch.publish_message(session2, message2)
# Give some time for the message to be processed
await anyio.sleep(0.1)
# Verify session2 received the message
callback2.assert_called_once()
call_args = callback2.call_args[0][0]
assert call_args.root.method == "test2"
@pytest.mark.anyio
async def test_listener_task_handoff_on_cancellation(message_dispatch):
"""
Test that the single listening task is properly
handed off when a session is cancelled.
"""
session1 = uuid4()
session2 = uuid4()
session1_messages_received = 0
session2_messages_received = 0
async def callback1(msg):
nonlocal session1_messages_received
session1_messages_received += 1
async def callback2(msg):
nonlocal session2_messages_received
session2_messages_received += 1
# Create a cancel scope for session1
async with anyio.create_task_group() as tg:
session1_cancel_scope: anyio.CancelScope | None = None
async def session1_runner():
nonlocal session1_cancel_scope
with anyio.CancelScope() as cancel_scope:
session1_cancel_scope = cancel_scope
async with message_dispatch.subscribe(session1, callback1):
# Keep session alive until cancelled
await anyio.sleep_forever()
# Start session1
tg.start_soon(session1_runner)
# Wait for session1 to be established
await anyio.sleep(0.1)
assert session1 in message_dispatch._session_state
# Send message to session1 to verify it's working
message1 = types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1}
)
await message_dispatch.publish_message(session1, message1)
await anyio.sleep(0.1)
assert session1_messages_received == 1
# Start session2 while session1 is still active
async with message_dispatch.subscribe(session2, callback2):
# Both sessions should be active
assert session1 in message_dispatch._session_state
assert session2 in message_dispatch._session_state
# Cancel session1
assert session1_cancel_scope is not None
session1_cancel_scope.cancel()
# Wait for cancellation to complete
await anyio.sleep(0.1)
# Session1 should be gone, session2 should remain
assert session1 not in message_dispatch._session_state
assert session2 in message_dispatch._session_state
# Send message to session2 to verify the listener was handed off
message2 = types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
)
await message_dispatch.publish_message(session2, message2)
await anyio.sleep(0.1)
# Session2 should have received the message
assert session2_messages_received == 1
# Session1 shouldn't receive any more messages
assert session1_messages_received == 1
# Send another message to verify the listener is still working
message3 = types.JSONRPCMessage.model_validate(
{"jsonrpc": "2.0", "method": "test3", "params": {}, "id": 3}
)
await message_dispatch.publish_message(session2, message3)
await anyio.sleep(0.1)
assert session2_messages_received == 2

View File

@@ -1,260 +0,0 @@
"""
Integration tests for Redis message dispatch functionality.
These tests validate Redis message dispatch by making actual HTTP calls and testing
that messages flow correctly through the Redis backend.
This version runs the server in a task instead of a separate process to allow
access to the fakeredis instance for verification of Redis keys.
"""
import asyncio
import socket
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import anyio
import pytest
import uvicorn
from sse_starlette.sse import AppStatus
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.server import Server
from mcp.server.message_queue.redis import RedisMessageDispatch
from mcp.server.sse import SseServerTransport
from mcp.types import TextContent, Tool
SERVER_NAME = "test_server_for_redis_integration_v3"
@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
@pytest.fixture
def server_url(server_port: int) -> str:
return f"http://127.0.0.1:{server_port}"
class RedisTestServer(Server):
"""Test server with basic tool functionality."""
def __init__(self):
super().__init__(SERVER_NAME)
@self.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="test_tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="echo_message",
description="Echo a message back",
inputSchema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
),
]
@self.call_tool()
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
if name == "echo_message":
message = args.get("message", "")
return [TextContent(type="text", text=f"Echo: {message}")]
return [TextContent(type="text", text=f"Called {name}")]
@pytest.fixture()
async def redis_server_and_app(message_dispatch: RedisMessageDispatch):
"""Create a mock Redis instance and Starlette app for testing."""
# Create SSE transport with Redis message dispatch
sse = SseServerTransport("/messages/", message_dispatch=message_dispatch)
server = RedisTestServer()
async def handle_sse(request: Request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()
)
return Response()
@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
"""Manage the lifecycle of the application."""
try:
yield
finally:
await message_dispatch.close()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
lifespan=lifespan,
)
return app, message_dispatch, message_dispatch._redis
@pytest.fixture()
async def server_and_redis(redis_server_and_app, server_port: int):
"""Run the server in a task and return the Redis instance for inspection."""
app, message_dispatch, mock_redis = redis_server_and_app
# Create a server config
config = uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
server = uvicorn.Server(config=config)
try:
async with anyio.create_task_group() as tg:
# Start server in background
tg.start_soon(server.serve)
# Wait for server to be ready
max_attempts = 20
attempt = 0
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
await anyio.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Server failed to start after {max_attempts} attempts"
)
try:
yield mock_redis, message_dispatch
finally:
server.should_exit = True
finally:
# These class variables are set top-level in starlette-sse
# It isn't designed to be run multiple times in a single
# Python process so we need to manually reset them.
AppStatus.should_exit = False
AppStatus.should_exit_event = None
@pytest.fixture()
async def client_session(server_and_redis, server_url: str):
"""Create a client session for testing."""
async with sse_client(server_url + "/sse") as streams:
async with ClientSession(*streams) as session:
result = await session.initialize()
assert result.serverInfo.name == SERVER_NAME
yield session
@pytest.mark.anyio
async def test_redis_integration_key_verification(
server_and_redis, client_session
) -> None:
"""Test that Redis keys are created correctly for sessions."""
mock_redis, _ = server_and_redis
all_keys = await mock_redis.keys("*") # type: ignore
assert len(all_keys) > 0
session_key = None
for key in all_keys:
if key.startswith("mcp:pubsub:session_active:"):
session_key = key
break
assert session_key is not None, f"No session key found. Keys: {all_keys}"
ttl = await mock_redis.ttl(session_key) # type: ignore
assert ttl > 0, f"Session key should have TTL, got: {ttl}"
@pytest.mark.anyio
async def test_tool_calls(server_and_redis, client_session) -> None:
"""Test that messages are properly published through Redis."""
mock_redis, _ = server_and_redis
for i in range(3):
tool_result = await client_session.call_tool(
"echo_message", {"message": f"Test {i}"}
)
assert tool_result.content[0].text == f"Echo: Test {i}" # type: ignore
@pytest.mark.anyio
async def test_session_cleanup(server_and_redis, server_url: str) -> None:
"""Test Redis key cleanup when sessions end."""
mock_redis, _ = server_and_redis
session_keys_seen = set()
for i in range(3):
async with sse_client(server_url + "/sse") as streams:
async with ClientSession(*streams) as session:
await session.initialize()
all_keys = await mock_redis.keys("*") # type: ignore
for key in all_keys:
if key.startswith("mcp:pubsub:session_active:"):
session_keys_seen.add(key)
value = await mock_redis.get(key) # type: ignore
assert value == "1"
await anyio.sleep(0.1) # Give time for cleanup
all_keys = await mock_redis.keys("*") # type: ignore
assert (
len(all_keys) == 0
), f"Session keys should be cleaned up, found: {all_keys}"
# Verify we saw different session keys for each session
assert len(session_keys_seen) == 3, "Should have seen 3 unique session keys"
@pytest.mark.anyio
async def concurrent_tool_call(server_and_redis, server_url: str) -> None:
"""Test multiple clients and verify Redis key management."""
mock_redis, _ = server_and_redis
async def client_task(client_id: int) -> str:
async with sse_client(server_url + "/sse") as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(
"echo_message",
{"message": f"Message from client {client_id}"},
)
return result.content[0].text # type: ignore
# Run multiple clients concurrently
client_tasks = [client_task(i) for i in range(3)]
results = await asyncio.gather(*client_tasks)
# Verify all clients received their respective messages
assert len(results) == 3
for i, result in enumerate(results):
assert result == f"Echo: Message from client {i}"
# After all clients disconnect, keys should be cleaned up
await anyio.sleep(0.1) # Give time for cleanup
all_keys = await mock_redis.keys("*") # type: ignore
assert len(all_keys) == 0, f"Session keys should be cleaned up, found: {all_keys}"

View File

@@ -84,7 +84,7 @@ async def test_lowlevel_server_lifespan():
)
await send_stream1.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
@@ -100,7 +100,7 @@ async def test_lowlevel_server_lifespan():
# Send initialized notification
await send_stream1.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
@@ -112,7 +112,7 @@ async def test_lowlevel_server_lifespan():
# Call the tool to verify lifespan context
await send_stream1.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
@@ -188,7 +188,7 @@ async def test_fastmcp_server_lifespan():
)
await send_stream1.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
@@ -204,7 +204,7 @@ async def test_fastmcp_server_lifespan():
# Send initialized notification
await send_stream1.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
@@ -216,7 +216,7 @@ async def test_fastmcp_server_lifespan():
# Call the tool to verify lifespan context
await send_stream1.send(
SessionMessage(
message=JSONRPCMessage(
JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,

View File

@@ -51,7 +51,7 @@ async def test_stdio_server():
async with write_stream:
for response in responses:
session_message = SessionMessage(message=response)
session_message = SessionMessage(response)
await write_stream.send(session_message)
stdout.seek(0)