mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
490 lines
19 KiB
Python
490 lines
19 KiB
Python
import logging
|
|
from collections.abc import Callable
|
|
from contextlib import AsyncExitStack
|
|
from datetime import timedelta
|
|
from types import TracebackType
|
|
from typing import Any, Generic, Protocol, TypeVar
|
|
|
|
import anyio
|
|
import httpx
|
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
from pydantic import BaseModel
|
|
from typing_extensions import Self
|
|
|
|
from mcp.shared.exceptions import McpError
|
|
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
|
|
from mcp.types import (
|
|
CONNECTION_CLOSED,
|
|
INVALID_PARAMS,
|
|
CancelledNotification,
|
|
ClientNotification,
|
|
ClientRequest,
|
|
ClientResult,
|
|
ErrorData,
|
|
JSONRPCError,
|
|
JSONRPCMessage,
|
|
JSONRPCNotification,
|
|
JSONRPCRequest,
|
|
JSONRPCResponse,
|
|
ProgressNotification,
|
|
RequestParams,
|
|
ServerNotification,
|
|
ServerRequest,
|
|
ServerResult,
|
|
)
|
|
|
|
# --- Outgoing direction ------------------------------------------------------
# Each variable is constrained to the client-side or the server-side union, so
# a session is parameterised consistently for one role.
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)

# --- Incoming direction ------------------------------------------------------
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
# Results are validated into whatever pydantic model the caller asks for, so
# this one is bounded by BaseModel rather than constrained to a fixed pair.
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar(
    "ReceiveNotificationT",
    ClientNotification,
    ServerNotification,
)

# A request identifier is either a string or an integer.
RequestId = str | int
|
|
|
|
|
|
class ProgressFnT(Protocol):
|
|
"""Protocol for progress notification callbacks."""
|
|
|
|
async def __call__(
|
|
self, progress: float, total: float | None, message: str | None
|
|
) -> None: ...
|
|
|
|
|
|
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
    """Wraps one incoming MCP request and owns the lifecycle of its reply.

    Instances MUST be used as a context manager; ``respond`` and ``cancel``
    raise ``RuntimeError`` when called outside of one:

    Example:
        with request_responder as resp:
            await resp.respond(result)

    While the context is active:
        1. A fresh anyio cancel scope surrounds the handler.
        2. ``respond``/``cancel`` mark the request as completed.
        3. On exit, a completed request is reported through ``on_complete``
           and the cancel scope is closed.
    """

    def __init__(
        self,
        request_id: RequestId,
        request_meta: RequestParams.Meta | None,
        request: ReceiveRequestT,
        session: """BaseSession[
            SendRequestT,
            SendNotificationT,
            SendResultT,
            ReceiveRequestT,
            ReceiveNotificationT
        ]""",
        on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
        message_metadata: MessageMetadata = None,
    ) -> None:
        """Store the request details and initialise lifecycle bookkeeping.

        Args:
            request_id: ID of the request being answered.
            request_meta: The request's ``_meta`` parameters, if any.
            request: The validated request object.
            session: Session used to send the response back.
            on_complete: Invoked with this responder when the context exits
                after the request was completed.
            message_metadata: Transport metadata that accompanied the message.
        """
        # What was received.
        self.request_id = request_id
        self.request_meta = request_meta
        self.request = request
        self.message_metadata = message_metadata

        # How to answer and whom to notify afterwards.
        self._session = session
        self._on_complete = on_complete

        # Lifecycle state.
        self._completed = False
        self._cancel_scope = anyio.CancelScope()
        self._entered = False  # True only while inside the ``with`` block

    def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
        """Start tracking the request and open a new cancel scope for it."""
        self._entered = True
        # A new scope replaces the placeholder created in __init__.
        self._cancel_scope = anyio.CancelScope()
        self._cancel_scope.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Report completion (if any) and close the cancel scope."""
        try:
            if self._completed:
                self._on_complete(self)
        finally:
            # Runs even when the completion callback raises.
            self._entered = False
            if not self._cancel_scope:
                raise RuntimeError("No active cancel scope")
            self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)

    async def respond(self, response: SendResultT | ErrorData) -> None:
        """Send ``response`` as the reply to this request.

        Nothing is sent when the request has already been cancelled.

        Raises:
            RuntimeError: If called outside of the context manager block.
            AssertionError: If the request was already responded to.
        """
        if not self._entered:
            raise RuntimeError("RequestResponder must be used as a context manager")
        assert not self._completed, "Request already responded to"

        if self.cancelled:
            return

        self._completed = True
        await self._session._send_response(  # type: ignore[reportPrivateUsage]
            request_id=self.request_id, response=response
        )

    async def cancel(self) -> None:
        """Cancel the request's scope and answer it with a cancellation error.

        Raises:
            RuntimeError: If called outside of the context manager block, or
                if there is no active cancel scope.
        """
        if not self._entered:
            raise RuntimeError("RequestResponder must be used as a context manager")
        if not self._cancel_scope:
            raise RuntimeError("No active cancel scope")

        self._cancel_scope.cancel()
        # Completed requests are handed to on_complete when the context exits.
        self._completed = True

        # Tell the peer that the request ended because it was cancelled.
        cancellation = ErrorData(code=0, message="Request cancelled", data=None)
        await self._session._send_response(  # type: ignore[reportPrivateUsage]
            request_id=self.request_id,
            response=cancellation,
        )

    @property
    def in_flight(self) -> bool:
        """True while the request is neither completed nor cancelled."""
        return not (self._completed or self.cancelled)

    @property
    def cancelled(self) -> bool:
        """True once cancellation has been requested on the current scope."""
        return self._cancel_scope.cancel_called
|
|
|
|
|
|
class BaseSession(
    Generic[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ],
):
    """
    Implements an MCP "session" on top of read/write streams, including features
    like request/response linking, notifications, and progress.

    This class is an async context manager that automatically starts processing
    messages when entered.
    """

    # Send side of the one-slot channel on which each pending outgoing request
    # waits for its response, keyed by the outgoing request ID.
    _response_streams: dict[
        RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
    ]
    # Next outgoing request ID; starts at 0 and increases by one per request.
    _request_id: int
    # Responders for incoming requests that have not completed yet.
    _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
    # Progress callbacks keyed by the progress token (the outgoing request ID).
    _progress_callbacks: dict[RequestId, ProgressFnT]

    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
        write_stream: MemoryObjectSendStream[SessionMessage],
        receive_request_type: type[ReceiveRequestT],
        receive_notification_type: type[ReceiveNotificationT],
        # If none, reading will never time out
        read_timeout_seconds: timedelta | None = None,
    ) -> None:
        """Store the transport streams and initialise the bookkeeping tables.

        Args:
            read_stream: Stream of incoming session messages (or exceptions).
            write_stream: Stream on which outgoing session messages are sent.
            receive_request_type: Model used to validate incoming requests.
            receive_notification_type: Model used to validate incoming
                notifications.
            read_timeout_seconds: Default time to wait for a response; ``None``
                means wait indefinitely.
        """
        self._read_stream = read_stream
        self._write_stream = write_stream
        self._response_streams = {}
        self._request_id = 0
        self._receive_request_type = receive_request_type
        self._receive_notification_type = receive_notification_type
        self._session_read_timeout_seconds = read_timeout_seconds
        self._in_flight = {}
        self._progress_callbacks = {}
        self._exit_stack = AsyncExitStack()

    async def __aenter__(self) -> Self:
        """Open the task group and start the background receive loop."""
        self._task_group = anyio.create_task_group()
        await self._task_group.__aenter__()
        self._task_group.start_soon(self._receive_loop)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Close the exit stack, cancel background tasks and leave the group."""
        await self._exit_stack.aclose()
        # Using BaseSession as a context manager should not block on exit (this
        # would be very surprising behavior), so make sure to cancel the tasks
        # in the task group.
        self._task_group.cancel_scope.cancel()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    async def send_request(
        self,
        request: SendRequestT,
        result_type: type[ReceiveResultT],
        request_read_timeout_seconds: timedelta | None = None,
        metadata: MessageMetadata = None,
        progress_callback: ProgressFnT | None = None,
    ) -> ReceiveResultT:
        """
        Sends a request and wait for a response. Raises an McpError if the
        response contains an error. If a request read timeout is provided, it
        will take precedence over the session read timeout.

        Do not use this method to emit notifications! Use send_notification()
        instead.

        Args:
            request: The request to send.
            result_type: Model the successful result is validated into.
            request_read_timeout_seconds: Per-request timeout override.
            metadata: Transport metadata attached to the outgoing message.
            progress_callback: If given, the request ID is sent as the
                progress token and the callback is registered under it.

        Raises:
            McpError: If the peer answers with an error or the wait times out.
        """
        request_id = self._request_id
        self._request_id = request_id + 1

        response_stream, response_stream_reader = anyio.create_memory_object_stream[
            JSONRPCResponse | JSONRPCError
        ](1)

        # Everything after stream creation lives inside the try so that the
        # streams are closed and the registrations removed even when
        # serialising the request fails.
        try:
            self._response_streams[request_id] = response_stream

            # Set up progress token if progress callback is provided
            request_data = request.model_dump(
                by_alias=True, mode="json", exclude_none=True
            )
            if progress_callback is not None:
                # Use request_id as progress token
                params = request_data.setdefault("params", {})
                meta = params.setdefault("_meta", {})
                meta["progressToken"] = request_id
                # Store the callback for this request
                self._progress_callbacks[request_id] = progress_callback

            jsonrpc_request = JSONRPCRequest(
                jsonrpc="2.0",
                id=request_id,
                **request_data,
            )

            await self._write_stream.send(
                SessionMessage(
                    message=JSONRPCMessage(jsonrpc_request), metadata=metadata
                )
            )

            # request read timeout takes precedence over session read timeout
            timeout = None
            if request_read_timeout_seconds is not None:
                timeout = request_read_timeout_seconds.total_seconds()
            elif self._session_read_timeout_seconds is not None:
                timeout = self._session_read_timeout_seconds.total_seconds()

            try:
                with anyio.fail_after(timeout):
                    response_or_error = await response_stream_reader.receive()
            except TimeoutError:
                raise McpError(
                    ErrorData(
                        code=httpx.codes.REQUEST_TIMEOUT,
                        message=(
                            f"Timed out while waiting for response to "
                            f"{request.__class__.__name__}. Waited "
                            f"{timeout} seconds."
                        ),
                    )
                )

            if isinstance(response_or_error, JSONRPCError):
                raise McpError(response_or_error.error)
            else:
                return result_type.model_validate(response_or_error.result)

        finally:
            self._response_streams.pop(request_id, None)
            self._progress_callbacks.pop(request_id, None)
            await response_stream.aclose()
            await response_stream_reader.aclose()

    async def send_notification(
        self,
        notification: SendNotificationT,
        related_request_id: RequestId | None = None,
    ) -> None:
        """
        Emits a notification, which is a one-way message that does not expect
        a response.

        Args:
            notification: The notification to send.
            related_request_id: ID of the request this notification belongs
                to, if any. It is forwarded as message metadata.
        """
        # Some transport implementations may need to set the related_request_id
        # to attribute to the notifications to the request that triggered them.
        jsonrpc_notification = JSONRPCNotification(
            jsonrpc="2.0",
            **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
        )
        # Compare against None explicitly: 0 is a legitimate request ID (this
        # class issues it for its first request) and must not be treated as
        # "no related request".
        session_message = SessionMessage(
            message=JSONRPCMessage(jsonrpc_notification),
            metadata=ServerMessageMetadata(related_request_id=related_request_id)
            if related_request_id is not None
            else None,
        )
        await self._write_stream.send(session_message)

    async def _send_response(
        self, request_id: RequestId, response: SendResultT | ErrorData
    ) -> None:
        """Send ``response`` for ``request_id`` as a JSON-RPC response or error.

        ``ErrorData`` is wrapped in a ``JSONRPCError``; anything else is dumped
        and wrapped in a ``JSONRPCResponse``.
        """
        if isinstance(response, ErrorData):
            payload = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
        else:
            payload = JSONRPCResponse(
                jsonrpc="2.0",
                id=request_id,
                result=response.model_dump(
                    by_alias=True, mode="json", exclude_none=True
                ),
            )
        session_message = SessionMessage(message=JSONRPCMessage(payload))
        await self._write_stream.send(session_message)

    async def _receive_loop(self) -> None:
        """Read incoming messages until the read stream ends and dispatch them.

        Exceptions from the transport go to ``_handle_incoming``. Requests are
        validated, wrapped in a ``RequestResponder`` and offered to
        ``_received_request`` and then ``_handle_incoming``. Notifications
        drive cancellation and progress callbacks. Responses are routed to the
        ``send_request`` call waiting for them. Once the read stream ends,
        every still-pending request is failed with ``CONNECTION_CLOSED``.
        """
        async with (
            self._read_stream,
            self._write_stream,
        ):
            async for message in self._read_stream:
                if isinstance(message, Exception):
                    await self._handle_incoming(message)
                elif isinstance(message.message.root, JSONRPCRequest):
                    # NOTE: this try also covers _received_request() and
                    # _handle_incoming(), so an exception raised by a handler
                    # is reported to the peer as invalid parameters as well.
                    try:
                        validated_request = self._receive_request_type.model_validate(
                            message.message.root.model_dump(
                                by_alias=True, mode="json", exclude_none=True
                            )
                        )
                        responder = RequestResponder(
                            request_id=message.message.root.id,
                            request_meta=validated_request.root.params.meta
                            if validated_request.root.params
                            else None,
                            request=validated_request,
                            session=self,
                            on_complete=lambda r: self._in_flight.pop(
                                r.request_id, None
                            ),
                            message_metadata=message.metadata,
                        )
                        self._in_flight[responder.request_id] = responder
                        await self._received_request(responder)

                        if not responder._completed:  # type: ignore[reportPrivateUsage]
                            await self._handle_incoming(responder)
                    except Exception as e:
                        # For request validation errors, send a proper JSON-RPC error
                        # response instead of crashing the server
                        logging.warning(f"Failed to validate request: {e}")
                        logging.debug(
                            f"Message that failed validation: {message.message.root}"
                        )
                        error_response = JSONRPCError(
                            jsonrpc="2.0",
                            id=message.message.root.id,
                            error=ErrorData(
                                code=INVALID_PARAMS,
                                message="Invalid request parameters",
                                data="",
                            ),
                        )
                        session_message = SessionMessage(
                            message=JSONRPCMessage(error_response)
                        )
                        await self._write_stream.send(session_message)

                elif isinstance(message.message.root, JSONRPCNotification):
                    try:
                        notification = self._receive_notification_type.model_validate(
                            message.message.root.model_dump(
                                by_alias=True, mode="json", exclude_none=True
                            )
                        )
                        # Handle cancellation notifications
                        if isinstance(notification.root, CancelledNotification):
                            cancelled_id = notification.root.params.requestId
                            in_flight = self._in_flight.get(cancelled_id)
                            if in_flight is not None:
                                await in_flight.cancel()
                        else:
                            # Handle progress notifications callback
                            if isinstance(notification.root, ProgressNotification):
                                progress_token = notification.root.params.progressToken
                                # If there is a progress callback for this token,
                                # call it with the progress information
                                callback = self._progress_callbacks.get(progress_token)
                                if callback is not None:
                                    await callback(
                                        notification.root.params.progress,
                                        notification.root.params.total,
                                        notification.root.params.message,
                                    )
                            await self._received_notification(notification)
                            await self._handle_incoming(notification)
                    except Exception as e:
                        # For other validation errors, log and continue
                        logging.warning(
                            f"Failed to validate notification: {e}. "
                            f"Message was: {message.message.root}"
                        )
                else:  # Response or error
                    stream = self._response_streams.pop(message.message.root.id, None)
                    if stream:
                        await stream.send(message.message.root)
                    else:
                        await self._handle_incoming(
                            RuntimeError(
                                "Received response with an unknown "
                                f"request ID: {message}"
                            )
                        )

            # after the read stream is closed, we need to send errors
            # to any pending requests
            #
            # Iterate over a snapshot: each await below yields control, and a
            # concurrent send_request() may remove its entry from the dict in
            # the meantime, which would break iteration over the live dict.
            for pending_id, stream in list(self._response_streams.items()):
                error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
                try:
                    await stream.send(
                        JSONRPCError(jsonrpc="2.0", id=pending_id, error=error)
                    )
                    await stream.aclose()
                except (anyio.BrokenResourceError, anyio.ClosedResourceError):
                    # The waiting send_request() already gave up and closed its
                    # streams; there is nobody left to notify.
                    pass
            self._response_streams.clear()

    async def _received_request(
        self, responder: RequestResponder[ReceiveRequestT, SendResultT]
    ) -> None:
        """
        Can be overridden by subclasses to handle a request without needing to
        listen on the message stream.

        If the request is responded to within this method, it will not be
        forwarded on to the message stream.
        """

    async def _received_notification(self, notification: ReceiveNotificationT) -> None:
        """
        Can be overridden by subclasses to handle a notification without needing
        to listen on the message stream.
        """

    async def send_progress_notification(
        self,
        progress_token: str | int,
        progress: float,
        total: float | None = None,
        message: str | None = None,
    ) -> None:
        """
        Sends a progress notification for a request that is currently being
        processed.

        This base implementation has no body and therefore sends nothing;
        presumably concrete sessions override it (verify against subclasses).
        """

    async def _handle_incoming(
        self,
        req: RequestResponder[ReceiveRequestT, SendResultT]
        | ReceiveNotificationT
        | Exception,
    ) -> None:
        """A generic handler for incoming messages. Overwritten by subclasses."""
        pass
|