Fix #201: Move incoming message stream from BaseSession to ServerSession (#325)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
David Soria Parra
2025-03-24 14:14:14 +00:00
committed by GitHub
parent 9ae4df85fb
commit 568cbd1a66
9 changed files with 168 additions and 109 deletions

View File

@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any, Protocol
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
@@ -31,6 +32,23 @@ class LoggingFnT(Protocol):
) -> None: ...
class MessageHandlerFnT(Protocol):
async def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None: ...
async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
await anyio.lowlevel.checkpoint()
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
@@ -78,6 +96,7 @@ class ClientSession(
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
) -> None:
super().__init__(
read_stream,
@@ -89,6 +108,7 @@ class ClientSession(
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
@@ -337,10 +357,20 @@ class ClientSession(
types.ClientResult(root=types.EmptyResult())
)
async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)
async def _received_notification(
self, notification: types.ServerNotification
) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)