mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
Add progress notification callback for client (#721)
This commit is contained in:
@@ -8,7 +8,7 @@ from pydantic import AnyUrl, TypeAdapter
|
||||
import mcp.types as types
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import BaseSession, RequestResponder
|
||||
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
|
||||
@@ -270,18 +270,23 @@ class ClientSession(
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
"""Send a tools/call request with optional progress callback support."""
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||
params=types.CallToolRequestParams(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
request_read_timeout_seconds=read_timeout_seconds,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
|
||||
|
||||
@@ -963,7 +963,6 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
total: Optional total value e.g. 100
|
||||
message: Optional message e.g. Starting render...
|
||||
"""
|
||||
|
||||
progress_token = (
|
||||
self.request_context.meta.progressToken
|
||||
if self.request_context.meta
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Generic, Protocol, TypeVar
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
@@ -24,6 +24,7 @@ from mcp.types import (
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ProgressNotification,
|
||||
RequestParams,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
@@ -42,6 +43,14 @@ ReceiveNotificationT = TypeVar(
|
||||
RequestId = str | int
|
||||
|
||||
|
||||
class ProgressFnT(Protocol):
|
||||
"""Protocol for progress notification callbacks."""
|
||||
|
||||
async def __call__(
|
||||
self, progress: float, total: float | None, message: str | None
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
@@ -169,6 +178,7 @@ class BaseSession(
|
||||
]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_progress_callbacks: dict[RequestId, ProgressFnT]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -187,6 +197,7 @@ class BaseSession(
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
self._progress_callbacks = {}
|
||||
self._exit_stack = AsyncExitStack()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
@@ -214,6 +225,7 @@ class BaseSession(
|
||||
result_type: type[ReceiveResultT],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
@@ -231,15 +243,25 @@ class BaseSession(
|
||||
](1)
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
# Set up progress token if progress callback is provided
|
||||
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
if progress_callback is not None:
|
||||
# Use request_id as progress token
|
||||
if "params" not in request_data:
|
||||
request_data["params"] = {}
|
||||
if "_meta" not in request_data["params"]:
|
||||
request_data["params"]["_meta"] = {}
|
||||
request_data["params"]["_meta"]["progressToken"] = request_id
|
||||
# Store the callback for this request
|
||||
self._progress_callbacks[request_id] = progress_callback
|
||||
|
||||
try:
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
**request_data,
|
||||
)
|
||||
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
|
||||
@@ -275,6 +297,7 @@ class BaseSession(
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
self._progress_callbacks.pop(request_id, None)
|
||||
await response_stream.aclose()
|
||||
await response_stream_reader.aclose()
|
||||
|
||||
@@ -333,7 +356,6 @@ class BaseSession(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta
|
||||
@@ -363,6 +385,18 @@ class BaseSession(
|
||||
if cancelled_id in self._in_flight:
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
# Handle progress notifications callback
|
||||
if isinstance(notification.root, ProgressNotification):
|
||||
progress_token = notification.root.params.progressToken
|
||||
# If there is a progress callback for this token,
|
||||
# call it with the progress information
|
||||
if progress_token in self._progress_callbacks:
|
||||
callback = self._progress_callbacks[progress_token]
|
||||
await callback(
|
||||
notification.root.params.progress,
|
||||
notification.root.params.total,
|
||||
notification.root.params.message,
|
||||
)
|
||||
await self._received_notification(notification)
|
||||
await self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user