add callback for logging message notification (#314)

This commit is contained in:
ihrpr
2025-03-19 09:40:08 +00:00
committed by GitHub
parent a9aca20205
commit 08f4e01b8f
3 changed files with 113 additions and 1 deletions

View File

@@ -24,6 +24,13 @@ class ListRootsFnT(Protocol):
) -> types.ListRootsResult | types.ErrorData: ...
class LoggingFnT(Protocol):
async def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ...
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
@@ -43,6 +50,12 @@ async def _default_list_roots_callback(
)
async def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
pass
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
types.ClientResult | types.ErrorData
)
@@ -64,6 +77,7 @@ class ClientSession(
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
) -> None:
super().__init__(
read_stream,
@@ -74,6 +88,7 @@ class ClientSession(
)
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
@@ -321,3 +336,13 @@ class ClientSession(
return await responder.respond(
types.ClientResult(root=types.EmptyResult())
)
async def _received_notification(
self, notification: types.ServerNotification
) -> None:
"""Handle notifications from the server."""
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
case _:
pass

View File

@@ -9,7 +9,7 @@ from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
from mcp.server import Server
from mcp.types import JSONRPCMessage
@@ -56,6 +56,7 @@ async def create_connected_server_and_client_session(
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -84,6 +85,7 @@ async def create_connected_server_and_client_session(
read_timeout_seconds=read_timeout_seconds,
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
logging_callback=logging_callback,
) as client_session:
await client_session.initialize()
yield client_session