mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Merge pull request #206 from modelcontextprotocol/jerome/fix/188
Jerome/fix/188
This commit is contained in:
0
.git-blame-ignore-revs
Normal file
0
.git-blame-ignore-revs
Normal file
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
|
scratch/
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|||||||
@@ -68,9 +68,10 @@ import contextvars
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||||
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
|
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
|
||||||
|
|
||||||
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
@@ -469,41 +470,47 @@ class Server(Generic[LifespanResultT]):
|
|||||||
# in-process servers.
|
# in-process servers.
|
||||||
raise_exceptions: bool = False,
|
raise_exceptions: bool = False,
|
||||||
):
|
):
|
||||||
with warnings.catch_warnings(record=True) as w:
|
async with AsyncExitStack() as stack:
|
||||||
from contextlib import AsyncExitStack
|
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||||
|
session = await stack.enter_async_context(
|
||||||
async with AsyncExitStack() as stack:
|
ServerSession(read_stream, write_stream, initialization_options)
|
||||||
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
)
|
||||||
session = await stack.enter_async_context(
|
|
||||||
ServerSession(read_stream, write_stream, initialization_options)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
async for message in session.incoming_messages:
|
async for message in session.incoming_messages:
|
||||||
logger.debug(f"Received message: {message}")
|
logger.debug(f"Received message: {message}")
|
||||||
|
|
||||||
match message:
|
tg.start_soon(
|
||||||
case (
|
self._handle_message,
|
||||||
RequestResponder(
|
message,
|
||||||
request=types.ClientRequest(root=req)
|
session,
|
||||||
) as responder
|
lifespan_context,
|
||||||
):
|
raise_exceptions,
|
||||||
with responder:
|
)
|
||||||
await self._handle_request(
|
|
||||||
message,
|
|
||||||
req,
|
|
||||||
session,
|
|
||||||
lifespan_context,
|
|
||||||
raise_exceptions,
|
|
||||||
)
|
|
||||||
case types.ClientNotification(root=notify):
|
|
||||||
await self._handle_notification(notify)
|
|
||||||
|
|
||||||
for warning in w:
|
async def _handle_message(
|
||||||
logger.info(
|
self,
|
||||||
"Warning: %s: %s",
|
message: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||||
warning.category.__name__,
|
| types.ClientNotification
|
||||||
warning.message,
|
| Exception,
|
||||||
|
session: ServerSession,
|
||||||
|
lifespan_context: LifespanResultT,
|
||||||
|
raise_exceptions: bool = False,
|
||||||
|
):
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
match message:
|
||||||
|
case (
|
||||||
|
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
||||||
|
):
|
||||||
|
with responder:
|
||||||
|
await self._handle_request(
|
||||||
|
message, req, session, lifespan_context, raise_exceptions
|
||||||
)
|
)
|
||||||
|
case types.ClientNotification(root=notify):
|
||||||
|
await self._handle_notification(notify)
|
||||||
|
|
||||||
|
for warning in w:
|
||||||
|
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
|
||||||
|
|
||||||
async def _handle_request(
|
async def _handle_request(
|
||||||
self,
|
self,
|
||||||
|
|||||||
49
tests/issues/test_188_concurrency.py
Normal file
49
tests/issues/test_188_concurrency.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import anyio
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
from mcp.shared.memory import (
|
||||||
|
create_connected_server_and_client_session as create_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
_sleep_time_seconds = 0.01
|
||||||
|
_resource_name = "slow://slow_resource"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_messages_are_executed_concurrently():
|
||||||
|
server = FastMCP("test")
|
||||||
|
|
||||||
|
@server.tool("sleep")
|
||||||
|
async def sleep_tool():
|
||||||
|
await anyio.sleep(_sleep_time_seconds)
|
||||||
|
return "done"
|
||||||
|
|
||||||
|
@server.resource(_resource_name)
|
||||||
|
async def slow_resource():
|
||||||
|
await anyio.sleep(_sleep_time_seconds)
|
||||||
|
return "slow"
|
||||||
|
|
||||||
|
async with create_session(server._mcp_server) as client_session:
|
||||||
|
start_time = anyio.current_time()
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for _ in range(10):
|
||||||
|
tg.start_soon(client_session.call_tool, "sleep")
|
||||||
|
tg.start_soon(client_session.read_resource, AnyUrl(_resource_name))
|
||||||
|
|
||||||
|
end_time = anyio.current_time()
|
||||||
|
|
||||||
|
duration = end_time - start_time
|
||||||
|
assert duration < 3 * _sleep_time_seconds
|
||||||
|
print(duration)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
anyio.run(test_messages_are_executed_concurrently)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user