mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Add progress notification callback for client (#721)
This commit is contained in:
@@ -3,7 +3,7 @@ from collections.abc import Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Generic, Protocol, TypeVar
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
@@ -24,6 +24,7 @@ from mcp.types import (
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ProgressNotification,
|
||||
RequestParams,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
@@ -42,6 +43,14 @@ ReceiveNotificationT = TypeVar(
|
||||
RequestId = str | int
|
||||
|
||||
|
||||
class ProgressFnT(Protocol):
|
||||
"""Protocol for progress notification callbacks."""
|
||||
|
||||
async def __call__(
|
||||
self, progress: float, total: float | None, message: str | None
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
@@ -169,6 +178,7 @@ class BaseSession(
|
||||
]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_progress_callbacks: dict[RequestId, ProgressFnT]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -187,6 +197,7 @@ class BaseSession(
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
self._progress_callbacks = {}
|
||||
self._exit_stack = AsyncExitStack()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
@@ -214,6 +225,7 @@ class BaseSession(
|
||||
result_type: type[ReceiveResultT],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
@@ -231,15 +243,25 @@ class BaseSession(
|
||||
](1)
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
# Set up progress token if progress callback is provided
|
||||
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
if progress_callback is not None:
|
||||
# Use request_id as progress token
|
||||
if "params" not in request_data:
|
||||
request_data["params"] = {}
|
||||
if "_meta" not in request_data["params"]:
|
||||
request_data["params"]["_meta"] = {}
|
||||
request_data["params"]["_meta"]["progressToken"] = request_id
|
||||
# Store the callback for this request
|
||||
self._progress_callbacks[request_id] = progress_callback
|
||||
|
||||
try:
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
**request_data,
|
||||
)
|
||||
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
|
||||
@@ -275,6 +297,7 @@ class BaseSession(
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
self._progress_callbacks.pop(request_id, None)
|
||||
await response_stream.aclose()
|
||||
await response_stream_reader.aclose()
|
||||
|
||||
@@ -333,7 +356,6 @@ class BaseSession(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta
|
||||
@@ -363,6 +385,18 @@ class BaseSession(
|
||||
if cancelled_id in self._in_flight:
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
# Handle progress notifications callback
|
||||
if isinstance(notification.root, ProgressNotification):
|
||||
progress_token = notification.root.params.progressToken
|
||||
# If there is a progress callback for this token,
|
||||
# call it with the progress information
|
||||
if progress_token in self._progress_callbacks:
|
||||
callback = self._progress_callbacks[progress_token]
|
||||
await callback(
|
||||
notification.root.params.progress,
|
||||
notification.root.params.total,
|
||||
notification.root.params.message,
|
||||
)
|
||||
await self._received_notification(notification)
|
||||
await self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user