mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Made message handling concurrent
This commit is contained in:
committed by
David Soria Parra
parent
9abfe775cc
commit
da53a97ed9
@@ -68,9 +68,10 @@ import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
@@ -458,6 +459,30 @@ class Server(Generic[LifespanResultT]):
|
||||
|
||||
return decorator
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
| types.ClientNotification
|
||||
| Exception,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
match message:
|
||||
case (
|
||||
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
||||
):
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
message, req, session, lifespan_context, raise_exceptions
|
||||
)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
|
||||
for warning in w:
|
||||
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
|
||||
|
||||
async def run(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
@@ -469,41 +494,23 @@ class Server(Generic[LifespanResultT]):
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||
session = await stack.enter_async_context(
|
||||
ServerSession(read_stream, write_stream, initialization_options)
|
||||
)
|
||||
async with AsyncExitStack() as stack:
|
||||
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||
session = await stack.enter_async_context(
|
||||
ServerSession(read_stream, write_stream, initialization_options)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
match message:
|
||||
case (
|
||||
RequestResponder(
|
||||
request=types.ClientRequest(root=req)
|
||||
) as responder
|
||||
):
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
message,
|
||||
req,
|
||||
session,
|
||||
lifespan_context,
|
||||
raise_exceptions,
|
||||
)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
"Warning: %s: %s",
|
||||
warning.category.__name__,
|
||||
warning.message,
|
||||
)
|
||||
tg.start_soon(
|
||||
self._handle_message,
|
||||
message,
|
||||
session,
|
||||
lifespan_context,
|
||||
raise_exceptions,
|
||||
)
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user