mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Add message queue for SSE messages POST endpoint (#459)
This commit is contained in:
355
tests/server/message_dispatch/test_redis.py
Normal file
355
tests/server/message_dispatch/test_redis.py
Normal file
@@ -0,0 +1,355 @@
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.message_queue.redis import RedisMessageDispatch
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_session_heartbeat(message_dispatch):
|
||||
"""Test that session heartbeat refreshes TTL."""
|
||||
session_id = uuid4()
|
||||
|
||||
async with message_dispatch.subscribe(session_id, AsyncMock()):
|
||||
session_key = message_dispatch._session_key(session_id)
|
||||
|
||||
# Initial TTL
|
||||
initial_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore
|
||||
assert initial_ttl > 0
|
||||
|
||||
# Wait for heartbeat to run
|
||||
await anyio.sleep(message_dispatch._session_ttl / 2 + 0.5)
|
||||
|
||||
# TTL should be refreshed
|
||||
refreshed_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore
|
||||
assert refreshed_ttl > 0
|
||||
assert refreshed_ttl <= message_dispatch._session_ttl
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_subscribe_unsubscribe(message_dispatch):
|
||||
"""Test subscribing and unsubscribing from a session."""
|
||||
session_id = uuid4()
|
||||
callback = AsyncMock()
|
||||
|
||||
# Subscribe
|
||||
async with message_dispatch.subscribe(session_id, callback):
|
||||
# Check that session is tracked
|
||||
assert session_id in message_dispatch._session_state
|
||||
assert await message_dispatch.session_exists(session_id)
|
||||
|
||||
# After context exit, session should be cleaned up
|
||||
assert session_id not in message_dispatch._session_state
|
||||
assert not await message_dispatch.session_exists(session_id)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_publish_message_valid_json(message_dispatch: RedisMessageDispatch):
|
||||
"""Test publishing a valid JSON-RPC message."""
|
||||
session_id = uuid4()
|
||||
callback = AsyncMock()
|
||||
message = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
|
||||
)
|
||||
|
||||
# Subscribe to messages
|
||||
async with message_dispatch.subscribe(session_id, callback):
|
||||
# Publish message
|
||||
published = await message_dispatch.publish_message(
|
||||
session_id, SessionMessage(message=message)
|
||||
)
|
||||
assert published
|
||||
|
||||
# Give some time for the message to be processed
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Callback should have been called with the message
|
||||
callback.assert_called_once()
|
||||
call_args = callback.call_args[0][0]
|
||||
assert isinstance(call_args, SessionMessage)
|
||||
assert isinstance(call_args.message.root, types.JSONRPCRequest)
|
||||
assert (
|
||||
call_args.message.root.method == "test"
|
||||
) # Access method through root attribute
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_publish_message_invalid_json(message_dispatch):
|
||||
"""Test publishing an invalid JSON string."""
|
||||
session_id = uuid4()
|
||||
callback = AsyncMock()
|
||||
invalid_json = '{"invalid": "json",,}' # Invalid JSON
|
||||
|
||||
# Subscribe to messages
|
||||
async with message_dispatch.subscribe(session_id, callback):
|
||||
# Publish invalid message
|
||||
published = await message_dispatch.publish_message(session_id, invalid_json)
|
||||
assert published
|
||||
|
||||
# Give some time for the message to be processed
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Callback should have been called with a ValidationError
|
||||
callback.assert_called_once()
|
||||
error = callback.call_args[0][0]
|
||||
assert isinstance(error, ValidationError)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_publish_to_nonexistent_session(message_dispatch: RedisMessageDispatch):
|
||||
"""Test publishing to a session that doesn't exist."""
|
||||
session_id = uuid4()
|
||||
message = SessionMessage(
|
||||
message=types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
|
||||
)
|
||||
)
|
||||
|
||||
published = await message_dispatch.publish_message(session_id, message)
|
||||
assert not published
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_extract_session_id(message_dispatch):
|
||||
"""Test extracting session ID from channel name."""
|
||||
session_id = uuid4()
|
||||
channel = message_dispatch._session_channel(session_id)
|
||||
|
||||
# Valid channel
|
||||
extracted_id = message_dispatch._extract_session_id(channel)
|
||||
assert extracted_id == session_id
|
||||
|
||||
# Invalid channel format
|
||||
extracted_id = message_dispatch._extract_session_id("invalid_channel_name")
|
||||
assert extracted_id is None
|
||||
|
||||
# Invalid UUID in channel
|
||||
invalid_channel = f"{message_dispatch._prefix}session:invalid_uuid"
|
||||
extracted_id = message_dispatch._extract_session_id(invalid_channel)
|
||||
assert extracted_id is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multiple_sessions(message_dispatch: RedisMessageDispatch):
|
||||
"""Test handling multiple concurrent sessions."""
|
||||
session1 = uuid4()
|
||||
session2 = uuid4()
|
||||
callback1 = AsyncMock()
|
||||
callback2 = AsyncMock()
|
||||
|
||||
async with message_dispatch.subscribe(session1, callback1):
|
||||
async with message_dispatch.subscribe(session2, callback2):
|
||||
# Both sessions should exist
|
||||
assert await message_dispatch.session_exists(session1)
|
||||
assert await message_dispatch.session_exists(session2)
|
||||
|
||||
# Publish to session1
|
||||
message1 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1}
|
||||
)
|
||||
await message_dispatch.publish_message(
|
||||
session1, SessionMessage(message=message1)
|
||||
)
|
||||
|
||||
# Publish to session2
|
||||
message2 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
|
||||
)
|
||||
await message_dispatch.publish_message(
|
||||
session2, SessionMessage(message=message2)
|
||||
)
|
||||
|
||||
# Give some time for messages to be processed
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Check callbacks
|
||||
callback1.assert_called_once()
|
||||
callback2.assert_called_once()
|
||||
|
||||
call1_args = callback1.call_args[0][0]
|
||||
assert isinstance(call1_args, SessionMessage)
|
||||
assert call1_args.message.root.method == "test1" # type: ignore
|
||||
|
||||
call2_args = callback2.call_args[0][0]
|
||||
assert isinstance(call2_args, SessionMessage)
|
||||
assert call2_args.message.root.method == "test2" # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_task_group_cancellation(message_dispatch):
|
||||
"""Test that task group is properly cancelled when context exits."""
|
||||
session_id = uuid4()
|
||||
callback = AsyncMock()
|
||||
|
||||
async with message_dispatch.subscribe(session_id, callback):
|
||||
# Check that task group is active
|
||||
_, task_group = message_dispatch._session_state[session_id]
|
||||
assert task_group.cancel_scope.cancel_called is False
|
||||
|
||||
# After context exit, task group should be cancelled
|
||||
# And session state should be cleaned up
|
||||
assert session_id not in message_dispatch._session_state
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_session_cancellation_isolation(message_dispatch):
|
||||
"""Test that cancelling one session doesn't affect other sessions."""
|
||||
session1 = uuid4()
|
||||
session2 = uuid4()
|
||||
|
||||
# Create a blocking callback for session1 to ensure it's running when cancelled
|
||||
session1_event = anyio.Event()
|
||||
session1_started = anyio.Event()
|
||||
session1_cancelled = False
|
||||
|
||||
async def blocking_callback1(msg):
|
||||
session1_started.set()
|
||||
try:
|
||||
await session1_event.wait()
|
||||
except anyio.get_cancelled_exc_class():
|
||||
nonlocal session1_cancelled
|
||||
session1_cancelled = True
|
||||
raise
|
||||
|
||||
callback2 = AsyncMock()
|
||||
|
||||
# Start session2 first
|
||||
async with message_dispatch.subscribe(session2, callback2):
|
||||
# Start session1 with a blocking callback
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
async def session1_runner():
|
||||
async with message_dispatch.subscribe(session1, blocking_callback1):
|
||||
# Publish a message to trigger the blocking callback
|
||||
message = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
|
||||
)
|
||||
await message_dispatch.publish_message(session1, message)
|
||||
|
||||
# Wait for the callback to start
|
||||
await session1_started.wait()
|
||||
|
||||
# Keep the context alive while we test cancellation
|
||||
await anyio.sleep_forever()
|
||||
|
||||
tg.start_soon(session1_runner)
|
||||
|
||||
# Wait for session1's callback to start
|
||||
await session1_started.wait()
|
||||
|
||||
# Cancel session1
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
# Give some time for cancellation to propagate
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Verify session1 was cancelled
|
||||
assert session1_cancelled
|
||||
assert session1 not in message_dispatch._session_state
|
||||
|
||||
# Verify session2 is still active and can receive messages
|
||||
assert await message_dispatch.session_exists(session2)
|
||||
message2 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
|
||||
)
|
||||
await message_dispatch.publish_message(session2, message2)
|
||||
|
||||
# Give some time for the message to be processed
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Verify session2 received the message
|
||||
callback2.assert_called_once()
|
||||
call_args = callback2.call_args[0][0]
|
||||
assert call_args.root.method == "test2"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_listener_task_handoff_on_cancellation(message_dispatch):
|
||||
"""
|
||||
Test that the single listening task is properly
|
||||
handed off when a session is cancelled.
|
||||
"""
|
||||
session1 = uuid4()
|
||||
session2 = uuid4()
|
||||
|
||||
session1_messages_received = 0
|
||||
session2_messages_received = 0
|
||||
|
||||
async def callback1(msg):
|
||||
nonlocal session1_messages_received
|
||||
session1_messages_received += 1
|
||||
|
||||
async def callback2(msg):
|
||||
nonlocal session2_messages_received
|
||||
session2_messages_received += 1
|
||||
|
||||
# Create a cancel scope for session1
|
||||
async with anyio.create_task_group() as tg:
|
||||
session1_cancel_scope: anyio.CancelScope | None = None
|
||||
|
||||
async def session1_runner():
|
||||
nonlocal session1_cancel_scope
|
||||
with anyio.CancelScope() as cancel_scope:
|
||||
session1_cancel_scope = cancel_scope
|
||||
async with message_dispatch.subscribe(session1, callback1):
|
||||
# Keep session alive until cancelled
|
||||
await anyio.sleep_forever()
|
||||
|
||||
# Start session1
|
||||
tg.start_soon(session1_runner)
|
||||
|
||||
# Wait for session1 to be established
|
||||
await anyio.sleep(0.1)
|
||||
assert session1 in message_dispatch._session_state
|
||||
|
||||
# Send message to session1 to verify it's working
|
||||
message1 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1}
|
||||
)
|
||||
await message_dispatch.publish_message(session1, message1)
|
||||
await anyio.sleep(0.1)
|
||||
assert session1_messages_received == 1
|
||||
|
||||
# Start session2 while session1 is still active
|
||||
async with message_dispatch.subscribe(session2, callback2):
|
||||
# Both sessions should be active
|
||||
assert session1 in message_dispatch._session_state
|
||||
assert session2 in message_dispatch._session_state
|
||||
|
||||
# Cancel session1
|
||||
assert session1_cancel_scope is not None
|
||||
session1_cancel_scope.cancel()
|
||||
|
||||
# Wait for cancellation to complete
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Session1 should be gone, session2 should remain
|
||||
assert session1 not in message_dispatch._session_state
|
||||
assert session2 in message_dispatch._session_state
|
||||
|
||||
# Send message to session2 to verify the listener was handed off
|
||||
message2 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
|
||||
)
|
||||
await message_dispatch.publish_message(session2, message2)
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# Session2 should have received the message
|
||||
assert session2_messages_received == 1
|
||||
|
||||
# Session1 shouldn't receive any more messages
|
||||
assert session1_messages_received == 1
|
||||
|
||||
# Send another message to verify the listener is still working
|
||||
message3 = types.JSONRPCMessage.model_validate(
|
||||
{"jsonrpc": "2.0", "method": "test3", "params": {}, "id": 3}
|
||||
)
|
||||
await message_dispatch.publish_message(session2, message3)
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
assert session2_messages_received == 2
|
||||
Reference in New Issue
Block a user