mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Add client handling for sampling, list roots, ping (#218)
Adds sampling and list roots callbacks to the ClientSession, allowing the client to handle requests from the server. Co-authored-by: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Co-authored-by: David Soria Parra <davidsp@anthropic.com>
This commit is contained in:
14
README.md
14
README.md
@@ -476,9 +476,21 @@ server_params = StdioServerParameters(
|
|||||||
env=None # Optional environment variables
|
env=None # Optional environment variables
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Optional: create a sampling callback
|
||||||
|
async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult:
|
||||||
|
return types.CreateMessageResult(
|
||||||
|
role="assistant",
|
||||||
|
content=types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text="Hello, world! from model",
|
||||||
|
),
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
stopReason="endTurn",
|
||||||
|
)
|
||||||
|
|
||||||
async def run():
|
async def run():
|
||||||
async with stdio_client(server_params) as (read, write):
|
async with stdio_client(server_params) as (read, write):
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session:
|
||||||
# Initialize the connection
|
# Initialize the connection
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,51 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl, TypeAdapter
|
||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.shared.session import BaseSession
|
from mcp.shared.context import RequestContext
|
||||||
|
from mcp.shared.session import BaseSession, RequestResponder
|
||||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingFnT(Protocol):
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
context: RequestContext["ClientSession", Any],
|
||||||
|
params: types.CreateMessageRequestParams,
|
||||||
|
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ListRootsFnT(Protocol):
|
||||||
|
async def __call__(
|
||||||
|
self, context: RequestContext["ClientSession", Any]
|
||||||
|
) -> types.ListRootsResult | types.ErrorData: ...
|
||||||
|
|
||||||
|
|
||||||
|
async def _default_sampling_callback(
|
||||||
|
context: RequestContext["ClientSession", Any],
|
||||||
|
params: types.CreateMessageRequestParams,
|
||||||
|
) -> types.CreateMessageResult | types.ErrorData:
|
||||||
|
return types.ErrorData(
|
||||||
|
code=types.INVALID_REQUEST,
|
||||||
|
message="Sampling not supported",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _default_list_roots_callback(
|
||||||
|
context: RequestContext["ClientSession", Any],
|
||||||
|
) -> types.ListRootsResult | types.ErrorData:
|
||||||
|
return types.ErrorData(
|
||||||
|
code=types.INVALID_REQUEST,
|
||||||
|
message="List roots not supported",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||||
|
|
||||||
|
|
||||||
class ClientSession(
|
class ClientSession(
|
||||||
BaseSession[
|
BaseSession[
|
||||||
types.ClientRequest,
|
types.ClientRequest,
|
||||||
@@ -22,6 +60,8 @@ class ClientSession(
|
|||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
|
sampling_callback: SamplingFnT | None = None,
|
||||||
|
list_roots_callback: ListRootsFnT | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
read_stream,
|
read_stream,
|
||||||
@@ -30,8 +70,24 @@ class ClientSession(
|
|||||||
types.ServerNotification,
|
types.ServerNotification,
|
||||||
read_timeout_seconds=read_timeout_seconds,
|
read_timeout_seconds=read_timeout_seconds,
|
||||||
)
|
)
|
||||||
|
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||||
|
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||||
|
|
||||||
async def initialize(self) -> types.InitializeResult:
|
async def initialize(self) -> types.InitializeResult:
|
||||||
|
sampling = (
|
||||||
|
types.SamplingCapability() if self._sampling_callback is not None else None
|
||||||
|
)
|
||||||
|
roots = (
|
||||||
|
types.RootsCapability(
|
||||||
|
# TODO: Should this be based on whether we
|
||||||
|
# _will_ send notifications, or only whether
|
||||||
|
# they're supported?
|
||||||
|
listChanged=True,
|
||||||
|
)
|
||||||
|
if self._list_roots_callback is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
result = await self.send_request(
|
result = await self.send_request(
|
||||||
types.ClientRequest(
|
types.ClientRequest(
|
||||||
types.InitializeRequest(
|
types.InitializeRequest(
|
||||||
@@ -39,14 +95,9 @@ class ClientSession(
|
|||||||
params=types.InitializeRequestParams(
|
params=types.InitializeRequestParams(
|
||||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||||
capabilities=types.ClientCapabilities(
|
capabilities=types.ClientCapabilities(
|
||||||
sampling=None,
|
sampling=sampling,
|
||||||
experimental=None,
|
experimental=None,
|
||||||
roots=types.RootsCapability(
|
roots=roots,
|
||||||
# TODO: Should this be based on whether we
|
|
||||||
# _will_ send notifications, or only whether
|
|
||||||
# they're supported?
|
|
||||||
listChanged=True
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
|
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
|
||||||
),
|
),
|
||||||
@@ -243,3 +294,32 @@ class ClientSession(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _received_request(
|
||||||
|
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||||
|
) -> None:
|
||||||
|
ctx = RequestContext[ClientSession, Any](
|
||||||
|
request_id=responder.request_id,
|
||||||
|
meta=responder.request_meta,
|
||||||
|
session=self,
|
||||||
|
lifespan_context=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
match responder.request.root:
|
||||||
|
case types.CreateMessageRequest(params=params):
|
||||||
|
with responder:
|
||||||
|
response = await self._sampling_callback(ctx, params)
|
||||||
|
client_response = ClientResponse.validate_python(response)
|
||||||
|
await responder.respond(client_response)
|
||||||
|
|
||||||
|
case types.ListRootsRequest():
|
||||||
|
with responder:
|
||||||
|
response = await self._list_roots_callback(ctx)
|
||||||
|
client_response = ClientResponse.validate_python(response)
|
||||||
|
await responder.respond(client_response)
|
||||||
|
|
||||||
|
case types.PingRequest():
|
||||||
|
with responder:
|
||||||
|
return await responder.respond(
|
||||||
|
types.ClientResult(root=types.EmptyResult())
|
||||||
|
)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import AsyncGenerator
|
|||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
|
|
||||||
from mcp.client.session import ClientSession
|
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.types import JSONRPCMessage
|
from mcp.types import JSONRPCMessage
|
||||||
|
|
||||||
@@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> (
|
|||||||
async def create_connected_server_and_client_session(
|
async def create_connected_server_and_client_session(
|
||||||
server: Server,
|
server: Server,
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
|
sampling_callback: SamplingFnT | None = None,
|
||||||
|
list_roots_callback: ListRootsFnT | None = None,
|
||||||
raise_exceptions: bool = False,
|
raise_exceptions: bool = False,
|
||||||
) -> AsyncGenerator[ClientSession, None]:
|
) -> AsyncGenerator[ClientSession, None]:
|
||||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||||
@@ -80,6 +82,8 @@ async def create_connected_server_and_client_session(
|
|||||||
read_stream=client_read,
|
read_stream=client_read,
|
||||||
write_stream=client_write,
|
write_stream=client_write,
|
||||||
read_timeout_seconds=read_timeout_seconds,
|
read_timeout_seconds=read_timeout_seconds,
|
||||||
|
sampling_callback=sampling_callback,
|
||||||
|
list_roots_callback=list_roots_callback,
|
||||||
) as client_session:
|
) as client_session:
|
||||||
await client_session.initialize()
|
await client_session.initialize()
|
||||||
yield client_session
|
yield client_session
|
||||||
|
|||||||
70
tests/client/test_list_roots_callback.py
Normal file
70
tests/client/test_list_roots_callback.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import pytest
|
||||||
|
from pydantic import FileUrl
|
||||||
|
|
||||||
|
from mcp.client.session import ClientSession
|
||||||
|
from mcp.server.fastmcp.server import Context
|
||||||
|
from mcp.shared.context import RequestContext
|
||||||
|
from mcp.shared.memory import (
|
||||||
|
create_connected_server_and_client_session as create_session,
|
||||||
|
)
|
||||||
|
from mcp.types import (
|
||||||
|
ListRootsResult,
|
||||||
|
Root,
|
||||||
|
TextContent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_roots_callback():
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
server = FastMCP("test")
|
||||||
|
|
||||||
|
callback_return = ListRootsResult(
|
||||||
|
roots=[
|
||||||
|
Root(
|
||||||
|
uri=FileUrl("file://users/fake/test"),
|
||||||
|
name="Test Root 1",
|
||||||
|
),
|
||||||
|
Root(
|
||||||
|
uri=FileUrl("file://users/fake/test/2"),
|
||||||
|
name="Test Root 2",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_roots_callback(
|
||||||
|
context: RequestContext[ClientSession, None],
|
||||||
|
) -> ListRootsResult:
|
||||||
|
return callback_return
|
||||||
|
|
||||||
|
@server.tool("test_list_roots")
|
||||||
|
async def test_list_roots(context: Context, message: str):
|
||||||
|
roots = await context.session.list_roots()
|
||||||
|
assert roots == callback_return
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Test with list_roots callback
|
||||||
|
async with create_session(
|
||||||
|
server._mcp_server, list_roots_callback=list_roots_callback
|
||||||
|
) as client_session:
|
||||||
|
# Make a request to trigger sampling callback
|
||||||
|
result = await client_session.call_tool(
|
||||||
|
"test_list_roots", {"message": "test message"}
|
||||||
|
)
|
||||||
|
assert result.isError is False
|
||||||
|
assert isinstance(result.content[0], TextContent)
|
||||||
|
assert result.content[0].text == "true"
|
||||||
|
|
||||||
|
# Test without list_roots callback
|
||||||
|
async with create_session(server._mcp_server) as client_session:
|
||||||
|
# Make a request to trigger sampling callback
|
||||||
|
result = await client_session.call_tool(
|
||||||
|
"test_list_roots", {"message": "test message"}
|
||||||
|
)
|
||||||
|
assert result.isError is True
|
||||||
|
assert isinstance(result.content[0], TextContent)
|
||||||
|
assert (
|
||||||
|
result.content[0].text
|
||||||
|
== "Error executing tool test_list_roots: List roots not supported"
|
||||||
|
)
|
||||||
73
tests/client/test_sampling_callback.py
Normal file
73
tests/client/test_sampling_callback.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from mcp.client.session import ClientSession
|
||||||
|
from mcp.shared.context import RequestContext
|
||||||
|
from mcp.shared.memory import (
|
||||||
|
create_connected_server_and_client_session as create_session,
|
||||||
|
)
|
||||||
|
from mcp.types import (
|
||||||
|
CreateMessageRequestParams,
|
||||||
|
CreateMessageResult,
|
||||||
|
SamplingMessage,
|
||||||
|
TextContent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sampling_callback():
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
server = FastMCP("test")
|
||||||
|
|
||||||
|
callback_return = CreateMessageResult(
|
||||||
|
role="assistant",
|
||||||
|
content=TextContent(
|
||||||
|
type="text", text="This is a response from the sampling callback"
|
||||||
|
),
|
||||||
|
model="test-model",
|
||||||
|
stopReason="endTurn",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def sampling_callback(
|
||||||
|
context: RequestContext[ClientSession, None],
|
||||||
|
params: CreateMessageRequestParams,
|
||||||
|
) -> CreateMessageResult:
|
||||||
|
return callback_return
|
||||||
|
|
||||||
|
@server.tool("test_sampling")
|
||||||
|
async def test_sampling_tool(message: str):
|
||||||
|
value = await server.get_context().session.create_message(
|
||||||
|
messages=[
|
||||||
|
SamplingMessage(
|
||||||
|
role="user", content=TextContent(type="text", text=message)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
assert value == callback_return
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Test with sampling callback
|
||||||
|
async with create_session(
|
||||||
|
server._mcp_server, sampling_callback=sampling_callback
|
||||||
|
) as client_session:
|
||||||
|
# Make a request to trigger sampling callback
|
||||||
|
result = await client_session.call_tool(
|
||||||
|
"test_sampling", {"message": "Test message for sampling"}
|
||||||
|
)
|
||||||
|
assert result.isError is False
|
||||||
|
assert isinstance(result.content[0], TextContent)
|
||||||
|
assert result.content[0].text == "true"
|
||||||
|
|
||||||
|
# Test without sampling callback
|
||||||
|
async with create_session(server._mcp_server) as client_session:
|
||||||
|
# Make a request to trigger sampling callback
|
||||||
|
result = await client_session.call_tool(
|
||||||
|
"test_sampling", {"message": "Test message for sampling"}
|
||||||
|
)
|
||||||
|
assert result.isError is True
|
||||||
|
assert isinstance(result.content[0], TextContent)
|
||||||
|
assert (
|
||||||
|
result.content[0].text
|
||||||
|
== "Error executing tool test_sampling: Sampling not supported"
|
||||||
|
)
|
||||||
@@ -1,12 +1,17 @@
|
|||||||
|
import shutil
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||||
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||||
|
|
||||||
|
tee: str = shutil.which("tee") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.skipif(tee is None, reason="could not find tee command")
|
||||||
async def test_stdio_client():
|
async def test_stdio_client():
|
||||||
server_parameters = StdioServerParameters(command="/usr/bin/tee")
|
server_parameters = StdioServerParameters(command=tee)
|
||||||
|
|
||||||
async with stdio_client(server_parameters) as (read_stream, write_stream):
|
async with stdio_client(server_parameters) as (read_stream, write_stream):
|
||||||
# Test sending and receiving messages
|
# Test sending and receiving messages
|
||||||
|
|||||||
Reference in New Issue
Block a user