from contextlib import asynccontextmanager
from unittest.mock import patch

import pytest

import mcp.shared.memory
from mcp.shared.message import SessionMessage
from mcp.types import (
    JSONRPCNotification,
    JSONRPCRequest,
)


class SpyMemoryObjectSendStream:
    """Send-stream wrapper that records every message before forwarding it.

    Each message passed to ``send`` is appended to ``sent_messages`` and then
    sent on the wrapped stream. The message is recorded first, so it stays in
    ``sent_messages`` even if the underlying ``send`` raises.
    """

    def __init__(self, original_stream):
        # The wrapped stream; only its ``send`` and ``aclose`` are used here.
        self.original_stream = original_stream
        # Messages in the order they were handed to ``send``.
        self.sent_messages: list[SessionMessage] = []

    async def send(self, message: SessionMessage) -> None:
        """Record ``message`` and forward it to the wrapped stream."""
        self.sent_messages.append(message)
        await self.original_stream.send(message)

    async def aclose(self) -> None:
        """Close the wrapped stream."""
        await self.original_stream.aclose()

    async def __aenter__(self):
        """Enter the async context; returns the spy itself."""
        return self

    async def __aexit__(self, *args) -> None:
        """Leave the async context by closing the wrapped stream."""
        await self.aclose()


class StreamSpyCollection:
    """Pairs the client-side and server-side spies and offers query helpers."""

    def __init__(
        self,
        client_spy: SpyMemoryObjectSendStream,
        server_spy: SpyMemoryObjectSendStream,
    ):
        self.client = client_spy
        self.server = server_spy

    @staticmethod
    def _select(
        spy: SpyMemoryObjectSendStream,
        kind: type,
        method: str | None,
    ) -> list:
        """Return the payloads recorded by ``spy`` that are ``kind`` instances.

        When ``method`` is not ``None`` only payloads whose ``method``
        attribute equals it are kept. Recording order is preserved.
        """
        selected = []
        for sent in spy.sent_messages:
            payload = sent.message.root
            if not isinstance(payload, kind):
                continue
            if method is None or payload.method == method:
                selected.append(payload)
        return selected

    def clear(self) -> None:
        """Clear all captured messages."""
        self.client.sent_messages.clear()
        self.server.sent_messages.clear()

    def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
        """Get client-sent requests, optionally filtered by method."""
        return self._select(self.client, JSONRPCRequest, method)

    def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
        """Get server-sent requests, optionally filtered by method."""
        return self._select(self.server, JSONRPCRequest, method)

    def get_client_notifications(
        self, method: str | None = None
    ) -> list[JSONRPCNotification]:
        """Get client-sent notifications, optionally filtered by method."""
        return self._select(self.client, JSONRPCNotification, method)

    def get_server_notifications(
        self, method: str | None = None
    ) -> list[JSONRPCNotification]:
        """Get server-sent notifications, optionally filtered by method."""
        return self._select(self.server, JSONRPCNotification, method)


@pytest.fixture
def stream_spy():
    """Fixture that provides spies for both client and server write streams.

    For the duration of the test,
    ``mcp.shared.memory.create_client_server_memory_streams`` is replaced by a
    version that wraps both write streams in ``SpyMemoryObjectSendStream``.
    The fixture yields a zero-argument callable that returns a
    ``StreamSpyCollection``. The callable asserts that the patched stream
    factory has already been entered, so call it only after the streams exist.
    If the factory is entered more than once, the most recent spies are used.

    Example usage:

        async def test_something(stream_spy):
            # ... set up server and client ...
            spies = stream_spy()

            # Run some operation that sends messages
            await client.some_operation()

            # Check the messages
            requests = spies.get_client_requests(method="some/method")
            assert len(requests) == 1

            # Clear for the next operation
            spies.clear()
    """
    # Filled in by the patched factory each time it is entered.
    captured: dict[str, SpyMemoryObjectSendStream] = {}

    # Grab the real factory before patching so the wrapper can delegate to it.
    original_create_streams = mcp.shared.memory.create_client_server_memory_streams

    @asynccontextmanager
    async def patched_create_streams():
        async with original_create_streams() as (client_streams, server_streams):
            client_read, client_write = client_streams
            server_read, server_write = server_streams

            # Only the write ends are wrapped; read streams pass through.
            spy_client_write = SpyMemoryObjectSendStream(client_write)
            spy_server_write = SpyMemoryObjectSendStream(server_write)

            # Keep references so the test can inspect what was sent.
            captured["client"] = spy_client_write
            captured["server"] = spy_server_write

            yield (client_read, spy_client_write), (server_read, spy_server_write)

    def get_spy_collection() -> StreamSpyCollection:
        client_spy = captured.get("client")
        server_spy = captured.get("server")
        assert client_spy is not None, "client_spy was not initialized"
        assert server_spy is not None, "server_spy was not initialized"
        return StreamSpyCollection(client_spy, server_spy)

    # NOTE(review): this patches the attribute on ``mcp.shared.memory`` only;
    # presumably callers look the factory up through that module - confirm.
    with patch(
        "mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
    ):
        yield get_spy_collection