mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Merge pull request #206 from modelcontextprotocol/jerome/fix/188
Jerome/fix/188
This commit is contained in:
0
.git-blame-ignore-revs
Normal file
0
.git-blame-ignore-revs
Normal file
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
.DS_Store
|
||||
scratch/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
@@ -68,9 +68,10 @@ import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
@@ -469,41 +470,47 @@ class Server(Generic[LifespanResultT]):
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||
session = await stack.enter_async_context(
|
||||
ServerSession(read_stream, write_stream, initialization_options)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
match message:
|
||||
case (
|
||||
RequestResponder(
|
||||
request=types.ClientRequest(root=req)
|
||||
) as responder
|
||||
):
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
tg.start_soon(
|
||||
self._handle_message,
|
||||
message,
|
||||
req,
|
||||
session,
|
||||
lifespan_context,
|
||||
raise_exceptions,
|
||||
)
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
| types.ClientNotification
|
||||
| Exception,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
match message:
|
||||
case (
|
||||
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
||||
):
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
message, req, session, lifespan_context, raise_exceptions
|
||||
)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
"Warning: %s: %s",
|
||||
warning.category.__name__,
|
||||
warning.message,
|
||||
)
|
||||
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
|
||||
49
tests/issues/test_188_concurrency.py
Normal file
49
tests/issues/test_188_concurrency.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import anyio
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.shared.memory import (
|
||||
create_connected_server_and_client_session as create_session,
|
||||
)
|
||||
|
||||
_sleep_time_seconds = 0.01
|
||||
_resource_name = "slow://slow_resource"
|
||||
|
||||
|
||||
async def test_messages_are_executed_concurrently():
|
||||
server = FastMCP("test")
|
||||
|
||||
@server.tool("sleep")
|
||||
async def sleep_tool():
|
||||
await anyio.sleep(_sleep_time_seconds)
|
||||
return "done"
|
||||
|
||||
@server.resource(_resource_name)
|
||||
async def slow_resource():
|
||||
await anyio.sleep(_sleep_time_seconds)
|
||||
return "slow"
|
||||
|
||||
async with create_session(server._mcp_server) as client_session:
|
||||
start_time = anyio.current_time()
|
||||
async with anyio.create_task_group() as tg:
|
||||
for _ in range(10):
|
||||
tg.start_soon(client_session.call_tool, "sleep")
|
||||
tg.start_soon(client_session.read_resource, AnyUrl(_resource_name))
|
||||
|
||||
end_time = anyio.current_time()
|
||||
|
||||
duration = end_time - start_time
|
||||
assert duration < 3 * _sleep_time_seconds
|
||||
print(duration)
|
||||
|
||||
|
||||
def main():
|
||||
anyio.run(test_messages_are_executed_concurrently)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
main()
|
||||
Reference in New Issue
Block a user