mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
feat: add message to ProgressNotification (#435)
Co-authored-by: ihrpr <inna.hrpr@gmail.com>
This commit is contained in:
@@ -168,7 +168,11 @@ class ClientSession(
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
@@ -179,6 +183,7 @@ class ClientSession(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -952,13 +952,14 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
return self._request_context
|
||||
|
||||
async def report_progress(
|
||||
self, progress: float, total: float | None = None
|
||||
self, progress: float, total: float | None = None, message: str | None = None
|
||||
) -> None:
|
||||
"""Report progress for the current operation.
|
||||
|
||||
Args:
|
||||
progress: Current progress value e.g. 24
|
||||
total: Optional total value e.g. 100
|
||||
message: Optional message e.g. Starting render...
|
||||
"""
|
||||
|
||||
progress_token = (
|
||||
@@ -971,7 +972,10 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
return
|
||||
|
||||
await self.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=progress, total=total
|
||||
progress_token=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
|
||||
|
||||
@@ -37,7 +37,8 @@ Usage:
|
||||
3. Define notification handlers if needed:
|
||||
@server.progress_notification()
|
||||
async def handle_progress(
|
||||
progress_token: str | int, progress: float, total: float | None
|
||||
progress_token: str | int, progress: float, total: float | None,
|
||||
message: str | None
|
||||
) -> None:
|
||||
# Implementation
|
||||
|
||||
@@ -427,13 +428,18 @@ class Server(Generic[LifespanResultT]):
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
func: Callable[
|
||||
[str | int, float, float | None, str | None], Awaitable[None]
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: types.ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken, req.params.progress, req.params.total
|
||||
req.params.progressToken,
|
||||
req.params.progress,
|
||||
req.params.total,
|
||||
req.params.message,
|
||||
)
|
||||
|
||||
self.notification_handlers[types.ProgressNotification] = handler
|
||||
|
||||
@@ -282,6 +282,7 @@ class ServerSession(
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
related_request_id: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
@@ -293,6 +294,7 @@ class ServerSession(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
),
|
||||
)
|
||||
),
|
||||
|
||||
@@ -43,11 +43,11 @@ class ProgressContext(
|
||||
total: float | None
|
||||
current: float = field(default=0.0, init=False)
|
||||
|
||||
async def progress(self, amount: float) -> None:
|
||||
async def progress(self, amount: float, message: str | None = None) -> None:
|
||||
self.current += amount
|
||||
|
||||
await self.session.send_progress_notification(
|
||||
self.progress_token, self.current, total=self.total
|
||||
self.progress_token, self.current, total=self.total, message=message
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -401,7 +401,11 @@ class BaseSession(
|
||||
"""
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being
|
||||
|
||||
@@ -337,6 +337,11 @@ class ProgressNotificationParams(NotificationParams):
|
||||
total is unknown.
|
||||
"""
|
||||
total: float | None = None
|
||||
"""
|
||||
Message related to progress. This should provide relevant human readable
|
||||
progress information.
|
||||
"""
|
||||
message: str | None = None
|
||||
"""Total number of items to process (or total progress required), if known."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call():
|
||||
mock_session.send_progress_notification.call_count == 3
|
||||
), "All progress notifications should be sent"
|
||||
mock_session.send_progress_notification.assert_any_call(
|
||||
progress_token=0, progress=0.0, total=10.0
|
||||
progress_token=0, progress=0.0, total=10.0, message=None
|
||||
)
|
||||
mock_session.send_progress_notification.assert_any_call(
|
||||
progress_token=0, progress=5.0, total=10.0
|
||||
progress_token=0, progress=5.0, total=10.0, message=None
|
||||
)
|
||||
mock_session.send_progress_notification.assert_any_call(
|
||||
progress_token=0, progress=10.0, total=10.0
|
||||
progress_token=0, progress=10.0, total=10.0, message=None
|
||||
)
|
||||
|
||||
349
tests/shared/test_progress_notifications.py
Normal file
349
tests/shared/test_progress_notifications.py
Normal file
@@ -0,0 +1,349 @@
|
||||
from typing import Any, cast
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.server import Server
|
||||
from mcp.server.lowlevel import NotificationOptions
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.progress import progress
|
||||
from mcp.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
SessionMessage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_bidirectional_progress_notifications():
|
||||
"""Test that both client and server can send progress notifications."""
|
||||
# Create memory streams for client/server
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](5)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](5)
|
||||
|
||||
# Run a server session so we can send progress updates in tool
|
||||
async def run_server():
|
||||
# Create a server session
|
||||
async with ServerSession(
|
||||
client_to_server_receive,
|
||||
server_to_client_send,
|
||||
InitializationOptions(
|
||||
server_name="ProgressTestServer",
|
||||
server_version="0.1.0",
|
||||
capabilities=server.get_capabilities(NotificationOptions(), {}),
|
||||
),
|
||||
) as server_session:
|
||||
global serv_sesh
|
||||
|
||||
serv_sesh = server_session
|
||||
async for message in server_session.incoming_messages:
|
||||
try:
|
||||
await server._handle_message(message, server_session, ())
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# Track progress updates
|
||||
server_progress_updates = []
|
||||
client_progress_updates = []
|
||||
|
||||
# Progress tokens
|
||||
server_progress_token = "server_token_123"
|
||||
client_progress_token = "client_token_456"
|
||||
|
||||
# Create a server with progress capability
|
||||
server = Server(name="ProgressTestServer")
|
||||
|
||||
# Register progress handler
|
||||
@server.progress_notification()
|
||||
async def handle_progress(
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None,
|
||||
message: str | None,
|
||||
):
|
||||
server_progress_updates.append(
|
||||
{
|
||||
"token": progress_token,
|
||||
"progress": progress,
|
||||
"total": total,
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
|
||||
# Register list tool handler
|
||||
@server.list_tools()
|
||||
async def handle_list_tools() -> list[types.Tool]:
|
||||
return [
|
||||
types.Tool(
|
||||
name="test_tool",
|
||||
description="A tool that sends progress notifications <o/",
|
||||
inputSchema={},
|
||||
)
|
||||
]
|
||||
|
||||
# Register tool handler
|
||||
@server.call_tool()
|
||||
async def handle_call_tool(name: str, arguments: dict | None) -> list:
|
||||
# Make sure we received a progress token
|
||||
if name == "test_tool":
|
||||
if arguments and "_meta" in arguments:
|
||||
progressToken = arguments["_meta"]["progressToken"]
|
||||
|
||||
if not progressToken:
|
||||
raise ValueError("Empty progress token received")
|
||||
|
||||
if progressToken != client_progress_token:
|
||||
raise ValueError("Server sending back incorrect progressToken")
|
||||
|
||||
# Send progress notifications
|
||||
await serv_sesh.send_progress_notification(
|
||||
progress_token=progressToken,
|
||||
progress=0.25,
|
||||
total=1.0,
|
||||
message="Server progress 25%",
|
||||
)
|
||||
|
||||
await serv_sesh.send_progress_notification(
|
||||
progress_token=progressToken,
|
||||
progress=0.5,
|
||||
total=1.0,
|
||||
message="Server progress 50%",
|
||||
)
|
||||
|
||||
await serv_sesh.send_progress_notification(
|
||||
progress_token=progressToken,
|
||||
progress=1.0,
|
||||
total=1.0,
|
||||
message="Server progress 100%",
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError("Progress token not sent.")
|
||||
|
||||
return ["Tool executed successfully"]
|
||||
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
# Client message handler to store progress notifications
|
||||
async def handle_client_message(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
if isinstance(message, types.ServerNotification):
|
||||
if isinstance(message.root, types.ProgressNotification):
|
||||
params = message.root.params
|
||||
client_progress_updates.append(
|
||||
{
|
||||
"token": params.progressToken,
|
||||
"progress": params.progress,
|
||||
"total": params.total,
|
||||
"message": params.message,
|
||||
}
|
||||
)
|
||||
|
||||
# Test using client
|
||||
async with (
|
||||
ClientSession(
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
message_handler=handle_client_message,
|
||||
) as client_session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
# Start the server in a background task
|
||||
tg.start_soon(run_server)
|
||||
|
||||
# Initialize the client connection
|
||||
await client_session.initialize()
|
||||
|
||||
# Call list_tools with progress token
|
||||
await client_session.list_tools()
|
||||
|
||||
# Call test_tool with progress token
|
||||
await client_session.call_tool(
|
||||
"test_tool", {"_meta": {"progressToken": client_progress_token}}
|
||||
)
|
||||
|
||||
# Send progress notifications from client to server
|
||||
await client_session.send_progress_notification(
|
||||
progress_token=server_progress_token,
|
||||
progress=0.33,
|
||||
total=1.0,
|
||||
message="Client progress 33%",
|
||||
)
|
||||
|
||||
await client_session.send_progress_notification(
|
||||
progress_token=server_progress_token,
|
||||
progress=0.66,
|
||||
total=1.0,
|
||||
message="Client progress 66%",
|
||||
)
|
||||
|
||||
await client_session.send_progress_notification(
|
||||
progress_token=server_progress_token,
|
||||
progress=1.0,
|
||||
total=1.0,
|
||||
message="Client progress 100%",
|
||||
)
|
||||
|
||||
# Wait and exit
|
||||
await anyio.sleep(0.5)
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
# Verify client received progress updates from server
|
||||
assert len(client_progress_updates) == 3
|
||||
assert client_progress_updates[0]["token"] == client_progress_token
|
||||
assert client_progress_updates[0]["progress"] == 0.25
|
||||
assert client_progress_updates[0]["message"] == "Server progress 25%"
|
||||
assert client_progress_updates[2]["progress"] == 1.0
|
||||
|
||||
# Verify server received progress updates from client
|
||||
assert len(server_progress_updates) == 3
|
||||
assert server_progress_updates[0]["token"] == server_progress_token
|
||||
assert server_progress_updates[0]["progress"] == 0.33
|
||||
assert server_progress_updates[0]["message"] == "Client progress 33%"
|
||||
assert server_progress_updates[2]["progress"] == 1.0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_progress_context_manager():
|
||||
"""Test client using progress context manager for sending progress notifications."""
|
||||
# Create memory streams for client/server
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](5)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](5)
|
||||
|
||||
# Track progress updates
|
||||
server_progress_updates = []
|
||||
|
||||
server = Server(name="ProgressContextTestServer")
|
||||
|
||||
# Register progress handler
|
||||
@server.progress_notification()
|
||||
async def handle_progress(
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None,
|
||||
message: str | None,
|
||||
):
|
||||
server_progress_updates.append(
|
||||
{
|
||||
"token": progress_token,
|
||||
"progress": progress,
|
||||
"total": total,
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
|
||||
# Run server session to receive progress updates
|
||||
async def run_server():
|
||||
# Create a server session
|
||||
async with ServerSession(
|
||||
client_to_server_receive,
|
||||
server_to_client_send,
|
||||
InitializationOptions(
|
||||
server_name="ProgressContextTestServer",
|
||||
server_version="0.1.0",
|
||||
capabilities=server.get_capabilities(NotificationOptions(), {}),
|
||||
),
|
||||
) as server_session:
|
||||
async for message in server_session.incoming_messages:
|
||||
try:
|
||||
await server._handle_message(message, server_session, ())
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# Client message handler
|
||||
async def handle_client_message(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
# run client session
|
||||
async with (
|
||||
ClientSession(
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
message_handler=handle_client_message,
|
||||
) as client_session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
tg.start_soon(run_server)
|
||||
|
||||
await client_session.initialize()
|
||||
|
||||
progress_token = "client_token_456"
|
||||
|
||||
# Create request context
|
||||
meta = types.RequestParams.Meta(progressToken=progress_token)
|
||||
request_context = RequestContext(
|
||||
request_id="test-request",
|
||||
session=client_session,
|
||||
meta=meta,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
# cast for type checker
|
||||
typed_context = cast(
|
||||
RequestContext[
|
||||
BaseSession[Any, Any, Any, Any, Any],
|
||||
Any,
|
||||
],
|
||||
request_context,
|
||||
)
|
||||
|
||||
# Utilize progress context manager
|
||||
with progress(typed_context, total=100) as p:
|
||||
await p.progress(10, message="Loading configuration...")
|
||||
await p.progress(30, message="Connecting to database...")
|
||||
await p.progress(40, message="Fetching data...")
|
||||
await p.progress(20, message="Processing results...")
|
||||
|
||||
# Wait for all messages to be processed
|
||||
await anyio.sleep(0.5)
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
# Verify progress updates were received by server
|
||||
assert len(server_progress_updates) == 4
|
||||
|
||||
# first update
|
||||
assert server_progress_updates[0]["token"] == progress_token
|
||||
assert server_progress_updates[0]["progress"] == 10
|
||||
assert server_progress_updates[0]["total"] == 100
|
||||
assert server_progress_updates[0]["message"] == "Loading configuration..."
|
||||
|
||||
# second update
|
||||
assert server_progress_updates[1]["token"] == progress_token
|
||||
assert server_progress_updates[1]["progress"] == 40
|
||||
assert server_progress_updates[1]["total"] == 100
|
||||
assert server_progress_updates[1]["message"] == "Connecting to database..."
|
||||
|
||||
# third update
|
||||
assert server_progress_updates[2]["token"] == progress_token
|
||||
assert server_progress_updates[2]["progress"] == 80
|
||||
assert server_progress_updates[2]["total"] == 100
|
||||
assert server_progress_updates[2]["message"] == "Fetching data..."
|
||||
|
||||
# final update
|
||||
assert server_progress_updates[3]["token"] == progress_token
|
||||
assert server_progress_updates[3]["progress"] == 100
|
||||
assert server_progress_updates[3]["total"] == 100
|
||||
assert server_progress_updates[3]["message"] == "Processing results..."
|
||||
Reference in New Issue
Block a user