Made message handling concurrent

This commit is contained in:
Jerome
2025-02-13 15:22:58 +13:00
committed by David Soria Parra
parent 9abfe775cc
commit da53a97ed9

View File

@@ -68,9 +68,10 @@ import contextvars
import logging import logging
import warnings import warnings
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
@@ -458,6 +459,30 @@ class Server(Generic[LifespanResultT]):
return decorator return decorator
async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
match message:
case (
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
with responder:
await self._handle_request(
message, req, session, lifespan_context, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
async def run( async def run(
self, self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
@@ -469,41 +494,23 @@ class Server(Generic[LifespanResultT]):
# in-process servers. # in-process servers.
raise_exceptions: bool = False, raise_exceptions: bool = False,
): ):
with warnings.catch_warnings(record=True) as w:
from contextlib import AsyncExitStack
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self)) lifespan_context = await stack.enter_async_context(self.lifespan(self))
session = await stack.enter_async_context( session = await stack.enter_async_context(
ServerSession(read_stream, write_stream, initialization_options) ServerSession(read_stream, write_stream, initialization_options)
) )
async with anyio.create_task_group() as tg:
async for message in session.incoming_messages: async for message in session.incoming_messages:
logger.debug(f"Received message: {message}") logger.debug(f"Received message: {message}")
match message: tg.start_soon(
case ( self._handle_message,
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
message, message,
req,
session, session,
lifespan_context, lifespan_context,
raise_exceptions, raise_exceptions,
) )
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(
"Warning: %s: %s",
warning.category.__name__,
warning.message,
)
async def _handle_request( async def _handle_request(
self, self,