# Mirror of https://github.com/aljazceru/mcp-python-sdk.git
# Synced 2025-12-19 14:54:24 +01:00 (356 lines, 13 KiB, Python)
from unittest.mock import AsyncMock
|
|
from uuid import uuid4
|
|
|
|
import anyio
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
import mcp.types as types
|
|
from mcp.server.message_queue.redis import RedisMessageDispatch
|
|
from mcp.shared.message import SessionMessage
|
|
|
|
|
|
@pytest.mark.anyio
async def test_session_heartbeat(message_dispatch):
    """Check the session key still carries a valid TTL after the heartbeat window.

    Subscribes a fresh session, reads the TTL of its Redis key, sleeps for half
    of the configured session TTL plus half a second, and then verifies the TTL
    is still positive and no larger than the configured session TTL.
    """
    sid = uuid4()

    async with message_dispatch.subscribe(sid, AsyncMock()):
        key = message_dispatch._session_key(sid)

        # TTL right after subscribing must already be set.
        ttl_before = await message_dispatch._redis.ttl(key)  # type: ignore
        assert ttl_before > 0

        # Sleep long enough for the heartbeat to have run.
        heartbeat_wait = message_dispatch._session_ttl / 2 + 0.5
        await anyio.sleep(heartbeat_wait)

        # TTL after the wait: still alive, and bounded by the configured TTL.
        ttl_after = await message_dispatch._redis.ttl(key)  # type: ignore
        assert ttl_after > 0
        assert message_dispatch._session_ttl >= ttl_after
|
|
|
|
|
|
@pytest.mark.anyio
async def test_subscribe_unsubscribe(message_dispatch):
    """Check that a session is tracked only while its subscription is open."""
    sid = uuid4()
    on_message = AsyncMock()

    async with message_dispatch.subscribe(sid, on_message):
        # Inside the context the session is registered locally and reported
        # as existing by the dispatch.
        assert sid in message_dispatch._session_state
        exists_while_open = await message_dispatch.session_exists(sid)
        assert exists_while_open

    # Leaving the context must remove both the local entry and the session.
    assert sid not in message_dispatch._session_state
    exists_after_close = await message_dispatch.session_exists(sid)
    assert not exists_after_close
|
|
|
|
|
|
@pytest.mark.anyio
async def test_publish_message_valid_json(message_dispatch: RedisMessageDispatch):
    """Check that a valid JSON-RPC request is delivered to the subscriber."""
    sid = uuid4()
    on_message = AsyncMock()
    payload = {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
    request = types.JSONRPCMessage.model_validate(payload)

    async with message_dispatch.subscribe(sid, on_message):
        # Publishing to a subscribed session must report success.
        was_published = await message_dispatch.publish_message(
            sid, SessionMessage(message=request)
        )
        assert was_published

        # Let the dispatch deliver the message before inspecting the mock.
        await anyio.sleep(0.1)

        # Exactly one delivery, carrying the request wrapped in a SessionMessage.
        on_message.assert_called_once()
        delivered = on_message.call_args.args[0]
        assert isinstance(delivered, SessionMessage)

        # The JSON-RPC payload itself lives on the message's root attribute.
        delivered_request = delivered.message.root
        assert isinstance(delivered_request, types.JSONRPCRequest)
        assert delivered_request.method == "test"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_publish_message_invalid_json(message_dispatch):
    """Check that a malformed JSON string reaches the callback as an error."""
    sid = uuid4()
    on_message = AsyncMock()
    malformed_payload = '{"invalid": "json",,}'  # doubled comma: not parseable

    async with message_dispatch.subscribe(sid, on_message):
        # The raw string is published as-is and must still report success.
        was_published = await message_dispatch.publish_message(sid, malformed_payload)
        assert was_published

        # Let the dispatch process the payload before inspecting the mock.
        await anyio.sleep(0.1)

        # The subscriber is called once, with the validation failure itself.
        on_message.assert_called_once()
        delivered = on_message.call_args.args[0]
        assert isinstance(delivered, ValidationError)
|
|
|
|
|
|
@pytest.mark.anyio
async def test_publish_to_nonexistent_session(message_dispatch: RedisMessageDispatch):
    """Check that publishing to a never-subscribed session reports failure."""
    unknown_session = uuid4()
    payload = {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
    request = types.JSONRPCMessage.model_validate(payload)
    outgoing = SessionMessage(message=request)

    # No subscribe() was opened for this id, so publish must return falsy.
    was_published = await message_dispatch.publish_message(unknown_session, outgoing)
    assert not was_published
|
|
|
|
|
|
@pytest.mark.anyio
async def test_extract_session_id(message_dispatch):
    """Check session-id extraction for valid and invalid channel names."""
    sid = uuid4()

    # Round trip: the channel built for a session yields that session's id.
    valid_channel = message_dispatch._session_channel(sid)
    assert message_dispatch._extract_session_id(valid_channel) == sid

    # A name that does not follow the channel format yields None.
    from_bad_format = message_dispatch._extract_session_id("invalid_channel_name")
    assert from_bad_format is None

    # A correctly prefixed channel whose id part is not a UUID yields None.
    bad_uuid_channel = f"{message_dispatch._prefix}session:invalid_uuid"
    from_bad_uuid = message_dispatch._extract_session_id(bad_uuid_channel)
    assert from_bad_uuid is None
|
|
|
|
|
|
@pytest.mark.anyio
async def test_multiple_sessions(message_dispatch: RedisMessageDispatch):
    """Check that two concurrent sessions each receive only their own message."""
    session1 = uuid4()
    session2 = uuid4()
    callback1 = AsyncMock()
    callback2 = AsyncMock()

    async with message_dispatch.subscribe(session1, callback1):
        async with message_dispatch.subscribe(session2, callback2):
            # Both subscriptions are open, so both sessions must exist.
            assert await message_dispatch.session_exists(session1)
            assert await message_dispatch.session_exists(session2)

            # Publish one distinct request per session, session1 first.
            for target, method, request_id in (
                (session1, "test1", 1),
                (session2, "test2", 2),
            ):
                request = types.JSONRPCMessage.model_validate(
                    {"jsonrpc": "2.0", "method": method, "params": {}, "id": request_id}
                )
                await message_dispatch.publish_message(
                    target, SessionMessage(message=request)
                )

            # Let the dispatch deliver both messages.
            await anyio.sleep(0.1)

            # Each callback fired exactly once...
            callback1.assert_called_once()
            callback2.assert_called_once()

            # ...and with the request addressed to its own session.
            delivered1 = callback1.call_args.args[0]
            assert isinstance(delivered1, SessionMessage)
            assert delivered1.message.root.method == "test1"  # type: ignore

            delivered2 = callback2.call_args.args[0]
            assert isinstance(delivered2, SessionMessage)
            assert delivered2.message.root.method == "test2"  # type: ignore
|
|
|
|
|
|
@pytest.mark.anyio
async def test_task_group_cancellation(message_dispatch):
    """Test that task group is properly cancelled when context exits.

    While the subscription is open, the task group stored for the session in
    ``_session_state`` must not have had cancellation requested. After the
    ``subscribe`` context exits, that same task group's cancel scope must
    report that cancellation was requested and the session entry must be gone.
    """
    session_id = uuid4()
    callback = AsyncMock()

    async with message_dispatch.subscribe(session_id, callback):
        # Check that task group is active
        _, task_group = message_dispatch._session_state[session_id]
        assert task_group.cancel_scope.cancel_called is False

    # After context exit, task group should be cancelled. The reference taken
    # inside the context is reused because the _session_state entry is removed
    # on exit and can no longer be looked up.
    assert task_group.cancel_scope.cancel_called

    # And session state should be cleaned up
    assert session_id not in message_dispatch._session_state
|
|
|
|
|
|
@pytest.mark.anyio
async def test_session_cancellation_isolation(message_dispatch):
    """Test that cancelling one session doesn't affect other sessions.

    Session 2 is subscribed first with a mock callback. Session 1 is then
    subscribed from a child task with a callback that blocks forever, and the
    child task group is cancelled while that callback is running. The test
    verifies that session 1's callback observed the cancellation, that session 1
    was removed from ``_session_state``, and that session 2 still exists and
    still receives messages.
    """
    session1 = uuid4()
    session2 = uuid4()

    # Create a blocking callback for session1 to ensure it's running when cancelled
    session1_event = anyio.Event()  # never set: the callback only ends by cancellation
    session1_started = anyio.Event()
    session1_cancelled = False

    async def blocking_callback1(msg):
        nonlocal session1_cancelled
        session1_started.set()
        try:
            await session1_event.wait()
        except anyio.get_cancelled_exc_class():
            # Record the cancellation, then let it propagate as anyio requires.
            session1_cancelled = True
            raise

    callback2 = AsyncMock()

    # Start session2 first
    async with message_dispatch.subscribe(session2, callback2):
        # Start session1 with a blocking callback
        async with anyio.create_task_group() as tg:

            async def session1_runner():
                async with message_dispatch.subscribe(session1, blocking_callback1):
                    # Publish a message to trigger the blocking callback.
                    # Wrapped in SessionMessage like every other publish in
                    # this module.
                    message = types.JSONRPCMessage.model_validate(
                        {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
                    )
                    await message_dispatch.publish_message(
                        session1, SessionMessage(message=message)
                    )

                    # Wait for the callback to start
                    await session1_started.wait()

                    # Keep the context alive while we test cancellation
                    await anyio.sleep_forever()

            tg.start_soon(session1_runner)

            # Wait for session1's callback to start
            await session1_started.wait()

            # Cancel session1
            tg.cancel_scope.cancel()

            # Give some time for cancellation to propagate. This sleep runs
            # inside the scope that was just cancelled; leaving the task group
            # block is what actually waits for session1_runner to finish.
            await anyio.sleep(0.1)

        # Verify session1 was cancelled
        assert session1_cancelled
        assert session1 not in message_dispatch._session_state

        # Verify session2 is still active and can receive messages
        assert await message_dispatch.session_exists(session2)
        message2 = types.JSONRPCMessage.model_validate(
            {"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
        )
        await message_dispatch.publish_message(
            session2, SessionMessage(message=message2)
        )

        # Give some time for the message to be processed
        await anyio.sleep(0.1)

        # Verify session2 received the message; the JSON-RPC payload is reached
        # through the SessionMessage's ``message`` attribute.
        callback2.assert_called_once()
        call_args = callback2.call_args[0][0]
        assert isinstance(call_args, SessionMessage)
        assert call_args.message.root.method == "test2"  # type: ignore
|
|
|
|
|
|
@pytest.mark.anyio
async def test_listener_task_handoff_on_cancellation(message_dispatch):
    """
    Test that the single listening task is properly
    handed off when a session is cancelled.

    Session 1 is subscribed from a child task inside its own cancel scope and
    receives one message. Session 2 is then subscribed, session 1's scope is
    cancelled, and two further messages are published to session 2. Session 2
    must receive both, and session 1's counter must stay at one.
    """
    session1 = uuid4()
    session2 = uuid4()

    session1_messages_received = 0
    session2_messages_received = 0

    async def callback1(msg):
        nonlocal session1_messages_received
        session1_messages_received += 1

    async def callback2(msg):
        nonlocal session2_messages_received
        session2_messages_received += 1

    # Create a cancel scope for session1
    async with anyio.create_task_group() as tg:
        session1_cancel_scope: anyio.CancelScope | None = None

        async def session1_runner():
            nonlocal session1_cancel_scope
            with anyio.CancelScope() as cancel_scope:
                # Expose the scope so the test body can cancel only session1.
                session1_cancel_scope = cancel_scope
                async with message_dispatch.subscribe(session1, callback1):
                    # Keep session alive until cancelled
                    await anyio.sleep_forever()

        # Start session1
        tg.start_soon(session1_runner)

        # Wait for session1 to be established
        await anyio.sleep(0.1)
        assert session1 in message_dispatch._session_state

        # Send message to session1 to verify it's working. Messages are wrapped
        # in SessionMessage like every other publish in this module.
        message1 = types.JSONRPCMessage.model_validate(
            {"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1}
        )
        await message_dispatch.publish_message(
            session1, SessionMessage(message=message1)
        )
        await anyio.sleep(0.1)
        assert session1_messages_received == 1

        # Start session2 while session1 is still active
        async with message_dispatch.subscribe(session2, callback2):
            # Both sessions should be active
            assert session1 in message_dispatch._session_state
            assert session2 in message_dispatch._session_state

            # Cancel session1
            assert session1_cancel_scope is not None
            session1_cancel_scope.cancel()

            # Wait for cancellation to complete
            await anyio.sleep(0.1)

            # Session1 should be gone, session2 should remain
            assert session1 not in message_dispatch._session_state
            assert session2 in message_dispatch._session_state

            # Send message to session2 to verify the listener was handed off
            message2 = types.JSONRPCMessage.model_validate(
                {"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
            )
            await message_dispatch.publish_message(
                session2, SessionMessage(message=message2)
            )
            await anyio.sleep(0.1)

            # Session2 should have received the message
            assert session2_messages_received == 1

            # Session1 shouldn't receive any more messages
            assert session1_messages_received == 1

            # Send another message to verify the listener is still working
            message3 = types.JSONRPCMessage.model_validate(
                {"jsonrpc": "2.0", "method": "test3", "params": {}, "id": 3}
            )
            await message_dispatch.publish_message(
                session2, SessionMessage(message=message3)
            )
            await anyio.sleep(0.1)

            assert session2_messages_received == 2
|