mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-20 15:24:25 +01:00
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
9ae4df85fb
commit
568cbd1a66
@@ -10,7 +10,13 @@ from typing import Any
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
|
||||
from mcp.client.session import (
|
||||
ClientSession,
|
||||
ListRootsFnT,
|
||||
LoggingFnT,
|
||||
MessageHandlerFnT,
|
||||
SamplingFnT,
|
||||
)
|
||||
from mcp.server import Server
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
@@ -58,6 +64,7 @@ async def create_connected_server_and_client_session(
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
raise_exceptions: bool = False,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||
@@ -87,6 +94,7 @@ async def create_connected_server_and_client_session(
|
||||
sampling_callback=sampling_callback,
|
||||
list_roots_callback=list_roots_callback,
|
||||
logging_callback=logging_callback,
|
||||
message_handler=message_handler,
|
||||
) as client_session:
|
||||
await client_session.initialize()
|
||||
yield client_session
|
||||
|
||||
@@ -189,19 +189,6 @@ class BaseSession(
|
||||
self._in_flight = {}
|
||||
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[
|
||||
RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception
|
||||
]()
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_reader.aclose()
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_writer.aclose()
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
self._task_group = anyio.create_task_group()
|
||||
@@ -312,11 +299,10 @@ class BaseSession(
|
||||
async with (
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
self._incoming_message_stream_writer,
|
||||
):
|
||||
async for message in self._read_stream:
|
||||
if isinstance(message, Exception):
|
||||
await self._incoming_message_stream_writer.send(message)
|
||||
await self._handle_incoming(message)
|
||||
elif isinstance(message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.root.model_dump(
|
||||
@@ -336,8 +322,9 @@ class BaseSession(
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
await self._received_request(responder)
|
||||
|
||||
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||
await self._incoming_message_stream_writer.send(responder)
|
||||
await self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
try:
|
||||
@@ -353,9 +340,7 @@ class BaseSession(
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(
|
||||
notification
|
||||
)
|
||||
await self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
@@ -367,7 +352,7 @@ class BaseSession(
|
||||
if stream:
|
||||
await stream.send(message.root)
|
||||
else:
|
||||
await self._incoming_message_stream_writer.send(
|
||||
await self._handle_incoming(
|
||||
RuntimeError(
|
||||
"Received response with an unknown "
|
||||
f"request ID: {message}"
|
||||
@@ -399,12 +384,11 @@ class BaseSession(
|
||||
processed.
|
||||
"""
|
||||
|
||||
@property
|
||||
def incoming_messages(
|
||||
async def _handle_incoming(
|
||||
self,
|
||||
) -> MemoryObjectReceiveStream[
|
||||
RequestResponder[ReceiveRequestT, SendResultT]
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception
|
||||
]:
|
||||
return self._incoming_message_stream_reader
|
||||
| Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user