mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
85 lines
1.9 KiB
Python
85 lines
1.9 KiB
Python
from collections.abc import Generator
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from typing import Generic
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from mcp.shared.context import LifespanContextT, RequestContext
|
|
from mcp.shared.session import (
|
|
BaseSession,
|
|
ReceiveNotificationT,
|
|
ReceiveRequestT,
|
|
SendNotificationT,
|
|
SendRequestT,
|
|
SendResultT,
|
|
)
|
|
from mcp.types import ProgressToken
|
|
|
|
|
|
class Progress(BaseModel):
    """Pydantic model pairing a progress value with an optional total."""

    # Progress value.
    progress: float
    # Total for the operation; ``None`` is accepted. No default is declared here.
    total: float | None
|
|
|
|
|
|
@dataclass
class ProgressContext(
    Generic[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ]
):
    """Running progress counter that reports through a session.

    Every call to :meth:`progress` adds to ``current`` and forwards the new
    value, together with ``progress_token`` and ``total``, to
    ``session.send_progress_notification``.
    """

    # Session whose ``send_progress_notification`` is awaited on each update.
    session: BaseSession[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ]
    # Token passed as the first argument of every notification.
    progress_token: ProgressToken
    # Forwarded unchanged as ``total=`` with every notification; may be None.
    total: float | None
    # Accumulated progress; starts at 0.0 and is not an ``__init__`` argument.
    current: float = field(default=0.0, init=False)

    async def progress(self, amount: float) -> None:
        """Add ``amount`` to ``current`` and send a progress notification.

        Args:
            amount: Increment added to ``current`` before notifying.
        """
        updated = self.current + amount
        self.current = updated

        await self.session.send_progress_notification(
            self.progress_token, updated, total=self.total
        )
|
|
|
|
|
|
@contextmanager
def progress(
    ctx: RequestContext[
        BaseSession[
            SendRequestT,
            SendNotificationT,
            SendResultT,
            ReceiveRequestT,
            ReceiveNotificationT,
        ],
        LifespanContextT,
    ],
    total: float | None = None,
) -> Generator[
    ProgressContext[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ],
    None,
]:
    """Yield a ``ProgressContext`` bound to the request's progress token.

    Args:
        ctx: Request context supplying ``session`` and ``meta.progressToken``.
        total: Value stored as ``total`` on the yielded context.

    Yields:
        A new ``ProgressContext`` built from ``ctx.session``, the progress
        token found on ``ctx.meta``, and ``total``.

    Raises:
        ValueError: If ``ctx.meta`` is ``None`` or its ``progressToken`` is
            ``None``.
    """
    meta = ctx.meta
    token = None if meta is None else meta.progressToken
    if token is None:
        raise ValueError("No progress token provided")

    # Nothing is torn down on exit, so the context is yielded directly.
    yield ProgressContext(
        session=ctx.session,
        progress_token=token,
        total=total,
    )
|