mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
9ae4df85fb
commit
568cbd1a66
@@ -7,9 +7,11 @@ from urllib.parse import urlparse
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
if not sys.warnoptions:
|
||||
@@ -21,26 +23,25 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def receive_loop(session: ClientSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
return
|
||||
|
||||
logger.info("Received message from server: %s", message)
|
||||
logger.info("Received message from server: %s", message)
|
||||
|
||||
|
||||
async def run_session(
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
):
|
||||
async with (
|
||||
ClientSession(read_stream, write_stream) as session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
tg.start_soon(receive_loop, session)
|
||||
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
logger.info("Initializing session")
|
||||
await session.initialize()
|
||||
logger.info("Initialized")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
@@ -31,6 +32,23 @@ class LoggingFnT(Protocol):
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
async def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
@@ -78,6 +96,7 @@ class ClientSession(
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
@@ -89,6 +108,7 @@ class ClientSession(
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
|
||||
async def initialize(self) -> types.InitializeResult:
|
||||
sampling = types.SamplingCapability()
|
||||
@@ -337,10 +357,20 @@ class ClientSession(
|
||||
types.ClientResult(root=types.EmptyResult())
|
||||
)
|
||||
|
||||
async def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
await self._message_handler(req)
|
||||
|
||||
async def _received_notification(
|
||||
self, notification: types.ServerNotification
|
||||
) -> None:
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
match notification.root:
|
||||
case types.LoggingMessageNotification(params=params):
|
||||
await self._logging_callback(params)
|
||||
|
||||
@@ -61,6 +61,12 @@ class InitializationState(Enum):
|
||||
|
||||
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
|
||||
|
||||
ServerRequestResponder = (
|
||||
RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
| types.ClientNotification
|
||||
| Exception
|
||||
)
|
||||
|
||||
|
||||
class ServerSession(
|
||||
BaseSession[
|
||||
@@ -85,6 +91,15 @@ class ServerSession(
|
||||
)
|
||||
self._initialization_state = InitializationState.NotInitialized
|
||||
self._init_options = init_options
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[ServerRequestResponder](0)
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_reader.aclose()
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_writer.aclose()
|
||||
)
|
||||
|
||||
@property
|
||||
def client_params(self) -> types.InitializeRequestParams | None:
|
||||
@@ -291,3 +306,12 @@ class ServerSession(
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
|
||||
await self._incoming_message_stream_writer.send(req)
|
||||
|
||||
@property
|
||||
def incoming_messages(
|
||||
self,
|
||||
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
|
||||
return self._incoming_message_stream_reader
|
||||
|
||||
@@ -10,7 +10,13 @@ from typing import Any
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
|
||||
from mcp.client.session import (
|
||||
ClientSession,
|
||||
ListRootsFnT,
|
||||
LoggingFnT,
|
||||
MessageHandlerFnT,
|
||||
SamplingFnT,
|
||||
)
|
||||
from mcp.server import Server
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
@@ -58,6 +64,7 @@ async def create_connected_server_and_client_session(
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
raise_exceptions: bool = False,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||
@@ -87,6 +94,7 @@ async def create_connected_server_and_client_session(
|
||||
sampling_callback=sampling_callback,
|
||||
list_roots_callback=list_roots_callback,
|
||||
logging_callback=logging_callback,
|
||||
message_handler=message_handler,
|
||||
) as client_session:
|
||||
await client_session.initialize()
|
||||
yield client_session
|
||||
|
||||
@@ -189,19 +189,6 @@ class BaseSession(
|
||||
self._in_flight = {}
|
||||
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[
|
||||
RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception
|
||||
]()
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_reader.aclose()
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_writer.aclose()
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
self._task_group = anyio.create_task_group()
|
||||
@@ -312,11 +299,10 @@ class BaseSession(
|
||||
async with (
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
self._incoming_message_stream_writer,
|
||||
):
|
||||
async for message in self._read_stream:
|
||||
if isinstance(message, Exception):
|
||||
await self._incoming_message_stream_writer.send(message)
|
||||
await self._handle_incoming(message)
|
||||
elif isinstance(message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.root.model_dump(
|
||||
@@ -336,8 +322,9 @@ class BaseSession(
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
await self._received_request(responder)
|
||||
|
||||
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||
await self._incoming_message_stream_writer.send(responder)
|
||||
await self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
try:
|
||||
@@ -353,9 +340,7 @@ class BaseSession(
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(
|
||||
notification
|
||||
)
|
||||
await self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
@@ -367,7 +352,7 @@ class BaseSession(
|
||||
if stream:
|
||||
await stream.send(message.root)
|
||||
else:
|
||||
await self._incoming_message_stream_writer.send(
|
||||
await self._handle_incoming(
|
||||
RuntimeError(
|
||||
"Received response with an unknown "
|
||||
f"request ID: {message}"
|
||||
@@ -399,12 +384,11 @@ class BaseSession(
|
||||
processed.
|
||||
"""
|
||||
|
||||
@property
|
||||
def incoming_messages(
|
||||
async def _handle_incoming(
|
||||
self,
|
||||
) -> MemoryObjectReceiveStream[
|
||||
RequestResponder[ReceiveRequestT, SendResultT]
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception
|
||||
]:
|
||||
return self._incoming_message_stream_reader
|
||||
| Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from typing import Literal
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.memory import (
|
||||
create_connected_server_and_client_session as create_session,
|
||||
)
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import (
|
||||
LoggingMessageNotificationParams,
|
||||
TextContent,
|
||||
@@ -46,40 +47,37 @@ async def test_logging_callback():
|
||||
)
|
||||
return True
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
async with create_session(
|
||||
server._mcp_server, logging_callback=logging_collector
|
||||
) as client_session:
|
||||
# Create a message handler to catch exceptions
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
async def listen_session():
|
||||
try:
|
||||
async for message in client_session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
except anyio.EndOfStream:
|
||||
pass
|
||||
async with create_session(
|
||||
server._mcp_server,
|
||||
logging_callback=logging_collector,
|
||||
message_handler=message_handler,
|
||||
) as client_session:
|
||||
# First verify our test tool works
|
||||
result = await client_session.call_tool("test_tool", {})
|
||||
assert result.isError is False
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == "true"
|
||||
|
||||
tg.start_soon(listen_session)
|
||||
|
||||
# First verify our test tool works
|
||||
result = await client_session.call_tool("test_tool", {})
|
||||
assert result.isError is False
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == "true"
|
||||
|
||||
# Now send a log message via our tool
|
||||
log_result = await client_session.call_tool(
|
||||
"test_tool_with_log",
|
||||
{
|
||||
"message": "Test log message",
|
||||
"level": "info",
|
||||
"logger": "test_logger",
|
||||
},
|
||||
)
|
||||
assert log_result.isError is False
|
||||
assert len(logging_collector.log_messages) == 1
|
||||
assert logging_collector.log_messages[
|
||||
0
|
||||
] == LoggingMessageNotificationParams(
|
||||
level="info", logger="test_logger", data="Test log message"
|
||||
)
|
||||
# Now send a log message via our tool
|
||||
log_result = await client_session.call_tool(
|
||||
"test_tool_with_log",
|
||||
{
|
||||
"message": "Test log message",
|
||||
"level": "info",
|
||||
"logger": "test_logger",
|
||||
},
|
||||
)
|
||||
assert log_result.isError is False
|
||||
assert len(logging_collector.log_messages) == 1
|
||||
assert logging_collector.log_messages[0] == LoggingMessageNotificationParams(
|
||||
level="info", logger="test_logger", data="Test log message"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
ClientNotification,
|
||||
@@ -75,13 +77,21 @@ async def test_client_session_initialize():
|
||||
)
|
||||
)
|
||||
|
||||
async def listen_session():
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
# Create a message handler to catch exceptions
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
async with (
|
||||
ClientSession(server_to_client_receive, client_to_server_send) as session,
|
||||
ClientSession(
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
message_handler=message_handler,
|
||||
) as session,
|
||||
anyio.create_task_group() as tg,
|
||||
client_to_server_send,
|
||||
client_to_server_receive,
|
||||
@@ -89,7 +99,6 @@ async def test_client_session_initialize():
|
||||
server_to_client_receive,
|
||||
):
|
||||
tg.start_soon(mock_server)
|
||||
tg.start_soon(listen_session)
|
||||
result = await session.initialize()
|
||||
|
||||
# Assert the result
|
||||
|
||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from anyio.abc import TaskStatus
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.server.lowlevel import Server
|
||||
@@ -54,15 +55,21 @@ async def test_notification_validation_error(tmp_path: Path):
|
||||
return [TextContent(type="text", text=f"fast {request_count}")]
|
||||
return [TextContent(type="text", text=f"unknown {request_count}")]
|
||||
|
||||
async def server_handler(read_stream, write_stream):
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
server.create_initialization_options(),
|
||||
raise_exceptions=True,
|
||||
)
|
||||
async def server_handler(
|
||||
read_stream,
|
||||
write_stream,
|
||||
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||
):
|
||||
with anyio.CancelScope() as scope:
|
||||
task_status.started(scope) # type: ignore
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
server.create_initialization_options(),
|
||||
raise_exceptions=True,
|
||||
)
|
||||
|
||||
async def client(read_stream, write_stream):
|
||||
async def client(read_stream, write_stream, scope):
|
||||
# Use a timeout that's:
|
||||
# - Long enough for fast operations (>10ms)
|
||||
# - Short enough for slow operations (<200ms)
|
||||
@@ -90,22 +97,13 @@ async def test_notification_validation_error(tmp_path: Path):
|
||||
# proving server is still responsive
|
||||
result = await session.call_tool("fast")
|
||||
assert result.content == [TextContent(type="text", text="fast 3")]
|
||||
scope.cancel()
|
||||
|
||||
# Run server and client in separate task groups to avoid cancellation
|
||||
server_writer, server_reader = anyio.create_memory_object_stream(1)
|
||||
client_writer, client_reader = anyio.create_memory_object_stream(1)
|
||||
|
||||
server_ready = anyio.Event()
|
||||
|
||||
async def wrapped_server_handler(read_stream, write_stream):
|
||||
server_ready.set()
|
||||
await server_handler(read_stream, write_stream)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
|
||||
# Wait for server to start and initialize
|
||||
with anyio.fail_after(1): # Timeout after 1 second
|
||||
await server_ready.wait()
|
||||
scope = await tg.start(server_handler, server_reader, client_writer)
|
||||
# Run client in a separate task to avoid cancellation
|
||||
async with anyio.create_task_group() as client_tg:
|
||||
client_tg.start_soon(client, client_reader, server_writer)
|
||||
tg.start_soon(client, client_reader, server_writer, scope)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.server import Server
|
||||
from mcp.server.lowlevel import NotificationOptions
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import (
|
||||
ClientNotification,
|
||||
InitializedNotification,
|
||||
@@ -25,10 +27,14 @@ async def test_server_session_initialize():
|
||||
JSONRPCMessage
|
||||
](1)
|
||||
|
||||
async def run_client(client: ClientSession):
|
||||
async for message in client_session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
# Create a message handler to catch exceptions
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
received_initialized = False
|
||||
|
||||
@@ -57,11 +63,12 @@ async def test_server_session_initialize():
|
||||
try:
|
||||
async with (
|
||||
ClientSession(
|
||||
server_to_client_receive, client_to_server_send
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
message_handler=message_handler,
|
||||
) as client_session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
tg.start_soon(run_client, client_session)
|
||||
tg.start_soon(run_server)
|
||||
|
||||
await client_session.initialize()
|
||||
|
||||
Reference in New Issue
Block a user