Made message handling concurrent

This commit is contained in:
Jerome
2025-02-13 15:22:58 +13:00
committed by David Soria Parra
parent 9abfe775cc
commit da53a97ed9

View File

@@ -68,9 +68,10 @@ import contextvars
import logging
import warnings
from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
@@ -458,6 +459,30 @@ class Server(Generic[LifespanResultT]):
return decorator
async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
match message:
case (
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
with responder:
await self._handle_request(
message, req, session, lifespan_context, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
async def run(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
@@ -469,41 +494,23 @@ class Server(Generic[LifespanResultT]):
# in-process servers.
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
from contextlib import AsyncExitStack
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
session = await stack.enter_async_context(
ServerSession(read_stream, write_stream, initialization_options)
)
async with anyio.create_task_group() as tg:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
match message:
case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
tg.start_soon(
self._handle_message,
message,
req,
session,
lifespan_context,
raise_exceptions,
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(
"Warning: %s: %s",
warning.category.__name__,
warning.message,
)
async def _handle_request(
self,