Merge pull request #206 from modelcontextprotocol/jerome/fix/188

Jerome/fix/188
This commit is contained in:
David Soria Parra
2025-02-13 15:09:42 +00:00
committed by GitHub
4 changed files with 87 additions and 30 deletions

0
.git-blame-ignore-revs Normal file
View File

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
.DS_Store .DS_Store
scratch/
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/

View File

@@ -68,9 +68,10 @@ import contextvars
import logging import logging
import warnings import warnings
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
@@ -469,41 +470,47 @@ class Server(Generic[LifespanResultT]):
# in-process servers. # in-process servers.
raise_exceptions: bool = False, raise_exceptions: bool = False,
): ):
with warnings.catch_warnings(record=True) as w: async with AsyncExitStack() as stack:
from contextlib import AsyncExitStack lifespan_context = await stack.enter_async_context(self.lifespan(self))
session = await stack.enter_async_context(
async with AsyncExitStack() as stack: ServerSession(read_stream, write_stream, initialization_options)
lifespan_context = await stack.enter_async_context(self.lifespan(self)) )
session = await stack.enter_async_context(
ServerSession(read_stream, write_stream, initialization_options)
)
async with anyio.create_task_group() as tg:
async for message in session.incoming_messages: async for message in session.incoming_messages:
logger.debug(f"Received message: {message}") logger.debug(f"Received message: {message}")
match message: tg.start_soon(
case ( self._handle_message,
RequestResponder( message,
request=types.ClientRequest(root=req) session,
) as responder lifespan_context,
): raise_exceptions,
with responder: )
await self._handle_request(
message,
req,
session,
lifespan_context,
raise_exceptions,
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w: async def _handle_message(
logger.info( self,
"Warning: %s: %s", message: RequestResponder[types.ClientRequest, types.ServerResult]
warning.category.__name__, | types.ClientNotification
warning.message, | Exception,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
match message:
case (
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
with responder:
await self._handle_request(
message, req, session, lifespan_context, raise_exceptions
) )
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
async def _handle_request( async def _handle_request(
self, self,

View File

@@ -0,0 +1,49 @@
import anyio
from pydantic import AnyUrl
from mcp.server.fastmcp import FastMCP
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
_sleep_time_seconds = 0.01
_resource_name = "slow://slow_resource"
async def test_messages_are_executed_concurrently():
server = FastMCP("test")
@server.tool("sleep")
async def sleep_tool():
await anyio.sleep(_sleep_time_seconds)
return "done"
@server.resource(_resource_name)
async def slow_resource():
await anyio.sleep(_sleep_time_seconds)
return "slow"
async with create_session(server._mcp_server) as client_session:
start_time = anyio.current_time()
async with anyio.create_task_group() as tg:
for _ in range(10):
tg.start_soon(client_session.call_tool, "sleep")
tg.start_soon(client_session.read_resource, AnyUrl(_resource_name))
end_time = anyio.current_time()
duration = end_time - start_time
assert duration < 3 * _sleep_time_seconds
print(duration)
def main():
anyio.run(test_messages_are_executed_concurrently)
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.DEBUG)
main()