mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
feat: add message to ProgressNotification (#435)
Co-authored-by: ihrpr <inna.hrpr@gmail.com>
This commit is contained in:
@@ -168,7 +168,11 @@ class ClientSession(
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
@@ -179,6 +183,7 @@ class ClientSession(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -952,13 +952,14 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
return self._request_context
|
||||
|
||||
async def report_progress(
|
||||
self, progress: float, total: float | None = None
|
||||
self, progress: float, total: float | None = None, message: str | None = None
|
||||
) -> None:
|
||||
"""Report progress for the current operation.
|
||||
|
||||
Args:
|
||||
progress: Current progress value e.g. 24
|
||||
total: Optional total value e.g. 100
|
||||
message: Optional message e.g. Starting render...
|
||||
"""
|
||||
|
||||
progress_token = (
|
||||
@@ -971,7 +972,10 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
return
|
||||
|
||||
await self.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=progress, total=total
|
||||
progress_token=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
|
||||
|
||||
@@ -37,7 +37,8 @@ Usage:
|
||||
3. Define notification handlers if needed:
|
||||
@server.progress_notification()
|
||||
async def handle_progress(
|
||||
progress_token: str | int, progress: float, total: float | None
|
||||
progress_token: str | int, progress: float, total: float | None,
|
||||
message: str | None
|
||||
) -> None:
|
||||
# Implementation
|
||||
|
||||
@@ -427,13 +428,18 @@ class Server(Generic[LifespanResultT]):
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
func: Callable[
|
||||
[str | int, float, float | None, str | None], Awaitable[None]
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: types.ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken, req.params.progress, req.params.total
|
||||
req.params.progressToken,
|
||||
req.params.progress,
|
||||
req.params.total,
|
||||
req.params.message,
|
||||
)
|
||||
|
||||
self.notification_handlers[types.ProgressNotification] = handler
|
||||
|
||||
@@ -282,6 +282,7 @@ class ServerSession(
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
related_request_id: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
@@ -293,6 +294,7 @@ class ServerSession(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
),
|
||||
)
|
||||
),
|
||||
|
||||
@@ -43,11 +43,11 @@ class ProgressContext(
|
||||
total: float | None
|
||||
current: float = field(default=0.0, init=False)
|
||||
|
||||
async def progress(self, amount: float) -> None:
|
||||
async def progress(self, amount: float, message: str | None = None) -> None:
|
||||
self.current += amount
|
||||
|
||||
await self.session.send_progress_notification(
|
||||
self.progress_token, self.current, total=self.total
|
||||
self.progress_token, self.current, total=self.total, message=message
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -401,7 +401,11 @@ class BaseSession(
|
||||
"""
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being
|
||||
|
||||
@@ -337,6 +337,11 @@ class ProgressNotificationParams(NotificationParams):
|
||||
total is unknown.
|
||||
"""
|
||||
total: float | None = None
|
||||
"""
|
||||
Message related to progress. This should provide relevant human readable
|
||||
progress information.
|
||||
"""
|
||||
message: str | None = None
|
||||
"""Total number of items to process (or total progress required), if known."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user