Add progress notification callback for client (#721)

This commit is contained in:
ihrpr
2025-05-15 17:45:58 +01:00
committed by GitHub
parent 1bdeed33c2
commit 5d33861cad
6 changed files with 609 additions and 12 deletions

View File

@@ -3,7 +3,7 @@ from collections.abc import Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Protocol, TypeVar
import anyio
import httpx
@@ -24,6 +24,7 @@ from mcp.types import (
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ProgressNotification,
RequestParams,
ServerNotification,
ServerRequest,
@@ -42,6 +43,14 @@ ReceiveNotificationT = TypeVar(
RequestId = str | int
class ProgressFnT(Protocol):
"""Protocol for progress notification callbacks."""
async def __call__(
self, progress: float, total: float | None, message: str | None
) -> None: ...
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
@@ -169,6 +178,7 @@ class BaseSession(
]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
def __init__(
self,
@@ -187,6 +197,7 @@ class BaseSession(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()
async def __aenter__(self) -> Self:
@@ -214,6 +225,7 @@ class BaseSession(
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
@@ -231,15 +243,25 @@ class BaseSession(
](1)
self._response_streams[request_id] = response_stream
# Set up progress token if progress callback is provided
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
if progress_callback is not None:
# Use request_id as progress token
if "params" not in request_data:
request_data["params"] = {}
if "_meta" not in request_data["params"]:
request_data["params"]["_meta"] = {}
request_data["params"]["_meta"]["progressToken"] = request_id
# Store the callback for this request
self._progress_callbacks[request_id] = progress_callback
try:
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
**request_data,
)
# TODO: Support progress callbacks
await self._write_stream.send(
SessionMessage(
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
@@ -275,6 +297,7 @@ class BaseSession(
finally:
self._response_streams.pop(request_id, None)
self._progress_callbacks.pop(request_id, None)
await response_stream.aclose()
await response_stream_reader.aclose()
@@ -333,7 +356,6 @@ class BaseSession(
by_alias=True, mode="json", exclude_none=True
)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
@@ -363,6 +385,18 @@ class BaseSession(
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e: