Merge pull request #206 from modelcontextprotocol/jerome/fix/188

Jerome/fix/188
This commit is contained in:
David Soria Parra
2025-02-13 15:09:42 +00:00
committed by GitHub
4 changed files with 87 additions and 30 deletions

0
.git-blame-ignore-revs Normal file
View File

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
.DS_Store
scratch/
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@@ -68,9 +68,10 @@ import contextvars
import logging
import warnings
from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
@@ -469,41 +470,47 @@ class Server(Generic[LifespanResultT]):
# in-process servers.
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
from contextlib import AsyncExitStack
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
session = await stack.enter_async_context(
ServerSession(read_stream, write_stream, initialization_options)
)
async with anyio.create_task_group() as tg:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
match message:
case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
tg.start_soon(
self._handle_message,
message,
req,
session,
lifespan_context,
raise_exceptions,
)
async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
match message:
case (
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
with responder:
await self._handle_request(
message, req, session, lifespan_context, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(
"Warning: %s: %s",
warning.category.__name__,
warning.message,
)
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
async def _handle_request(
self,

View File

@@ -0,0 +1,49 @@
import anyio
from pydantic import AnyUrl
from mcp.server.fastmcp import FastMCP
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
_sleep_time_seconds = 0.01
_resource_name = "slow://slow_resource"
async def test_messages_are_executed_concurrently():
server = FastMCP("test")
@server.tool("sleep")
async def sleep_tool():
await anyio.sleep(_sleep_time_seconds)
return "done"
@server.resource(_resource_name)
async def slow_resource():
await anyio.sleep(_sleep_time_seconds)
return "slow"
async with create_session(server._mcp_server) as client_session:
start_time = anyio.current_time()
async with anyio.create_task_group() as tg:
for _ in range(10):
tg.start_soon(client_session.call_tool, "sleep")
tg.start_soon(client_session.read_resource, AnyUrl(_resource_name))
end_time = anyio.current_time()
duration = end_time - start_time
assert duration < 3 * _sleep_time_seconds
print(duration)
def main():
anyio.run(test_messages_are_executed_concurrently)
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.DEBUG)
main()