Fix #201: Move incoming message stream from BaseSession to ServerSession (#325)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
David Soria Parra
2025-03-24 14:14:14 +00:00
committed by GitHub
parent 9ae4df85fb
commit 568cbd1a66
9 changed files with 168 additions and 109 deletions

View File

@@ -189,19 +189,6 @@ class BaseSession(
self._in_flight = {}
self._exit_stack = AsyncExitStack()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)
async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
@@ -312,11 +299,10 @@ class BaseSession(
async with (
self._read_stream,
self._write_stream,
self._incoming_message_stream_writer,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._incoming_message_stream_writer.send(message)
await self._handle_incoming(message)
elif isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
@@ -336,8 +322,9 @@ class BaseSession(
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._incoming_message_stream_writer.send(responder)
await self._handle_incoming(responder)
elif isinstance(message.root, JSONRPCNotification):
try:
@@ -353,9 +340,7 @@ class BaseSession(
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
@@ -367,7 +352,7 @@ class BaseSession(
if stream:
await stream.send(message.root)
else:
await self._incoming_message_stream_writer.send(
await self._handle_incoming(
RuntimeError(
"Received response with an unknown "
f"request ID: {message}"
@@ -399,12 +384,11 @@ class BaseSession(
processed.
"""
@property
def incoming_messages(
async def _handle_incoming(
self,
) -> MemoryObjectReceiveStream[
RequestResponder[ReceiveRequestT, SendResultT]
req: RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]:
return self._incoming_message_stream_reader
| Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass