fix: Pass cursor parameter to server (#745)

This commit is contained in:
Nate Barbettini
2025-05-21 14:27:06 -07:00
committed by GitHub
parent 2ca2de767b
commit e80c0150e1
5 changed files with 306 additions and 73 deletions

View File

@@ -209,7 +209,9 @@ class ClientSession(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListResourcesResult,
@@ -223,7 +225,9 @@ class ClientSession(
types.ClientRequest(
types.ListResourceTemplatesRequest(
method="resources/templates/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListResourceTemplatesResult,
@@ -295,7 +299,9 @@ class ClientSession(
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListPromptsResult,
@@ -340,7 +346,9 @@ class ClientSession(
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListToolsResult,

View File

@@ -53,6 +53,14 @@ class RequestParams(BaseModel):
meta: Meta | None = Field(alias="_meta", default=None)
class PaginatedRequestParams(RequestParams):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class NotificationParams(BaseModel):
class Meta(BaseModel):
model_config = ConfigDict(extra="allow")
@@ -79,12 +87,13 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class PaginatedRequest(
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]
):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@@ -358,13 +367,10 @@ class ProgressNotification(
params: ProgressNotificationParams
class ListResourcesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
):
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
"""Sent from the client to request a list of resources the server has."""
method: Literal["resources/list"]
params: RequestParams | None = None
class Annotations(BaseModel):
@@ -423,12 +429,11 @@ class ListResourcesResult(PaginatedResult):
class ListResourceTemplatesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
PaginatedRequest[Literal["resources/templates/list"]]
):
"""Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"]
params: RequestParams | None = None
class ListResourceTemplatesResult(PaginatedResult):
@@ -570,13 +575,10 @@ class ResourceUpdatedNotification(
params: ResourceUpdatedNotificationParams
class ListPromptsRequest(
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
):
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
"""Sent from the client to request a list of prompts and prompt templates."""
method: Literal["prompts/list"]
params: RequestParams | None = None
class PromptArgument(BaseModel):
@@ -703,11 +705,10 @@ class PromptListChangedNotification(
params: NotificationParams | None = None
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""
method: Literal["tools/list"]
params: RequestParams | None = None
class ToolAnnotations(BaseModel):

145
tests/client/conftest.py Normal file
View File

@@ -0,0 +1,145 @@
from contextlib import asynccontextmanager
from unittest.mock import patch
import pytest
import mcp.shared.memory
from mcp.shared.message import SessionMessage
from mcp.types import (
JSONRPCNotification,
JSONRPCRequest,
)
class SpyMemoryObjectSendStream:
def __init__(self, original_stream):
self.original_stream = original_stream
self.sent_messages: list[SessionMessage] = []
async def send(self, message):
self.sent_messages.append(message)
await self.original_stream.send(message)
async def aclose(self):
await self.original_stream.aclose()
async def __aenter__(self):
return self
async def __aexit__(self, *args):
await self.aclose()
class StreamSpyCollection:
def __init__(
self,
client_spy: SpyMemoryObjectSendStream,
server_spy: SpyMemoryObjectSendStream,
):
self.client = client_spy
self.server = server_spy
def clear(self) -> None:
"""Clear all captured messages."""
self.client.sent_messages.clear()
self.server.sent_messages.clear()
def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
"""Get client-sent requests, optionally filtered by method."""
return [
req.message.root
for req in self.client.sent_messages
if isinstance(req.message.root, JSONRPCRequest)
and (method is None or req.message.root.method == method)
]
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
"""Get server-sent requests, optionally filtered by method."""
return [
req.message.root
for req in self.server.sent_messages
if isinstance(req.message.root, JSONRPCRequest)
and (method is None or req.message.root.method == method)
]
def get_client_notifications(
self, method: str | None = None
) -> list[JSONRPCNotification]:
"""Get client-sent notifications, optionally filtered by method."""
return [
notif.message.root
for notif in self.client.sent_messages
if isinstance(notif.message.root, JSONRPCNotification)
and (method is None or notif.message.root.method == method)
]
def get_server_notifications(
self, method: str | None = None
) -> list[JSONRPCNotification]:
"""Get server-sent notifications, optionally filtered by method."""
return [
notif.message.root
for notif in self.server.sent_messages
if isinstance(notif.message.root, JSONRPCNotification)
and (method is None or notif.message.root.method == method)
]
@pytest.fixture
def stream_spy():
"""Fixture that provides spies for both client and server write streams.
Example usage:
async def test_something(stream_spy):
# ... set up server and client ...
spies = stream_spy()
# Run some operation that sends messages
await client.some_operation()
# Check the messages
requests = spies.get_client_requests(method="some/method")
assert len(requests) == 1
# Clear for the next operation
spies.clear()
"""
client_spy = None
server_spy = None
# Store references to our spy objects
def capture_spies(c_spy, s_spy):
nonlocal client_spy, server_spy
client_spy = c_spy
server_spy = s_spy
# Create patched version of stream creation
original_create_streams = mcp.shared.memory.create_client_server_memory_streams
@asynccontextmanager
async def patched_create_streams():
async with original_create_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams
# Create spy wrappers
spy_client_write = SpyMemoryObjectSendStream(client_write)
spy_server_write = SpyMemoryObjectSendStream(server_write)
# Capture references for the test to use
capture_spies(spy_client_write, spy_server_write)
yield (client_read, spy_client_write), (server_read, spy_server_write)
# Apply the patch for the duration of the test
with patch(
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
):
# Return a collection with helper methods
def get_spy_collection() -> StreamSpyCollection:
assert client_spy is not None, "client_spy was not initialized"
assert server_spy is not None, "server_spy was not initialized"
return StreamSpyCollection(client_spy, server_spy)
yield get_spy_collection

View File

@@ -9,11 +9,11 @@ from mcp.shared.memory import (
pytestmark = pytest.mark.anyio
async def test_list_tools_cursor_parameter():
"""Test that the cursor parameter is accepted for list_tools.
async def test_list_tools_cursor_parameter(stream_spy):
"""Test that the cursor parameter is accepted for list_tools
and that it is correctly passed to the server.
Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
server = FastMCP("test")
@@ -29,28 +29,46 @@ async def test_list_tools_cursor_parameter():
return "Result 2"
async with create_session(server._mcp_server) as client_session:
spies = stream_spy()
# Test without cursor parameter (omitted)
result1 = await client_session.list_tools()
assert len(result1.tools) == 2
_ = await client_session.list_tools()
list_tools_requests = spies.get_client_requests(method="tools/list")
assert len(list_tools_requests) == 1
assert list_tools_requests[0].params is None
spies.clear()
# Test with cursor=None
result2 = await client_session.list_tools(cursor=None)
assert len(result2.tools) == 2
_ = await client_session.list_tools(cursor=None)
list_tools_requests = spies.get_client_requests(method="tools/list")
assert len(list_tools_requests) == 1
assert list_tools_requests[0].params is None
spies.clear()
# Test with cursor as string
result3 = await client_session.list_tools(cursor="some_cursor_value")
assert len(result3.tools) == 2
_ = await client_session.list_tools(cursor="some_cursor_value")
list_tools_requests = spies.get_client_requests(method="tools/list")
assert len(list_tools_requests) == 1
assert list_tools_requests[0].params is not None
assert list_tools_requests[0].params["cursor"] == "some_cursor_value"
spies.clear()
# Test with empty string cursor
result4 = await client_session.list_tools(cursor="")
assert len(result4.tools) == 2
_ = await client_session.list_tools(cursor="")
list_tools_requests = spies.get_client_requests(method="tools/list")
assert len(list_tools_requests) == 1
assert list_tools_requests[0].params is not None
assert list_tools_requests[0].params["cursor"] == ""
async def test_list_resources_cursor_parameter():
"""Test that the cursor parameter is accepted for list_resources.
async def test_list_resources_cursor_parameter(stream_spy):
"""Test that the cursor parameter is accepted for list_resources
and that it is correctly passed to the server.
Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
server = FastMCP("test")
@@ -61,28 +79,45 @@ async def test_list_resources_cursor_parameter():
return "Test data"
async with create_session(server._mcp_server) as client_session:
spies = stream_spy()
# Test without cursor parameter (omitted)
result1 = await client_session.list_resources()
assert len(result1.resources) >= 1
_ = await client_session.list_resources()
list_resources_requests = spies.get_client_requests(method="resources/list")
assert len(list_resources_requests) == 1
assert list_resources_requests[0].params is None
spies.clear()
# Test with cursor=None
result2 = await client_session.list_resources(cursor=None)
assert len(result2.resources) >= 1
_ = await client_session.list_resources(cursor=None)
list_resources_requests = spies.get_client_requests(method="resources/list")
assert len(list_resources_requests) == 1
assert list_resources_requests[0].params is None
spies.clear()
# Test with cursor as string
result3 = await client_session.list_resources(cursor="some_cursor")
assert len(result3.resources) >= 1
_ = await client_session.list_resources(cursor="some_cursor")
list_resources_requests = spies.get_client_requests(method="resources/list")
assert len(list_resources_requests) == 1
assert list_resources_requests[0].params is not None
assert list_resources_requests[0].params["cursor"] == "some_cursor"
spies.clear()
# Test with empty string cursor
result4 = await client_session.list_resources(cursor="")
assert len(result4.resources) >= 1
_ = await client_session.list_resources(cursor="")
list_resources_requests = spies.get_client_requests(method="resources/list")
assert len(list_resources_requests) == 1
assert list_resources_requests[0].params is not None
assert list_resources_requests[0].params["cursor"] == ""
async def test_list_prompts_cursor_parameter():
"""Test that the cursor parameter is accepted for list_prompts.
Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
async def test_list_prompts_cursor_parameter(stream_spy):
"""Test that the cursor parameter is accepted for list_prompts
and that it is correctly passed to the server.
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
server = FastMCP("test")
@@ -93,28 +128,46 @@ async def test_list_prompts_cursor_parameter():
return f"Hello, {name}!"
async with create_session(server._mcp_server) as client_session:
spies = stream_spy()
# Test without cursor parameter (omitted)
result1 = await client_session.list_prompts()
assert len(result1.prompts) >= 1
_ = await client_session.list_prompts()
list_prompts_requests = spies.get_client_requests(method="prompts/list")
assert len(list_prompts_requests) == 1
assert list_prompts_requests[0].params is None
spies.clear()
# Test with cursor=None
result2 = await client_session.list_prompts(cursor=None)
assert len(result2.prompts) >= 1
_ = await client_session.list_prompts(cursor=None)
list_prompts_requests = spies.get_client_requests(method="prompts/list")
assert len(list_prompts_requests) == 1
assert list_prompts_requests[0].params is None
spies.clear()
# Test with cursor as string
result3 = await client_session.list_prompts(cursor="some_cursor")
assert len(result3.prompts) >= 1
_ = await client_session.list_prompts(cursor="some_cursor")
list_prompts_requests = spies.get_client_requests(method="prompts/list")
assert len(list_prompts_requests) == 1
assert list_prompts_requests[0].params is not None
assert list_prompts_requests[0].params["cursor"] == "some_cursor"
spies.clear()
# Test with empty string cursor
result4 = await client_session.list_prompts(cursor="")
assert len(result4.prompts) >= 1
_ = await client_session.list_prompts(cursor="")
list_prompts_requests = spies.get_client_requests(method="prompts/list")
assert len(list_prompts_requests) == 1
assert list_prompts_requests[0].params is not None
assert list_prompts_requests[0].params["cursor"] == ""
async def test_list_resource_templates_cursor_parameter():
"""Test that the cursor parameter is accepted for list_resource_templates.
async def test_list_resource_templates_cursor_parameter(stream_spy):
"""Test that the cursor parameter is accepted for list_resource_templates
and that it is correctly passed to the server.
Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
server = FastMCP("test")
@@ -125,18 +178,44 @@ async def test_list_resource_templates_cursor_parameter():
return f"Data for {name}"
async with create_session(server._mcp_server) as client_session:
spies = stream_spy()
# Test without cursor parameter (omitted)
result1 = await client_session.list_resource_templates()
assert len(result1.resourceTemplates) >= 1
_ = await client_session.list_resource_templates()
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None
spies.clear()
# Test with cursor=None
result2 = await client_session.list_resource_templates(cursor=None)
assert len(result2.resourceTemplates) >= 1
_ = await client_session.list_resource_templates(cursor=None)
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None
spies.clear()
# Test with cursor as string
result3 = await client_session.list_resource_templates(cursor="some_cursor")
assert len(result3.resourceTemplates) >= 1
_ = await client_session.list_resource_templates(cursor="some_cursor")
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == "some_cursor"
spies.clear()
# Test with empty string cursor
result4 = await client_session.list_resource_templates(cursor="")
assert len(result4.resourceTemplates) >= 1
_ = await client_session.list_resource_templates(cursor="")
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == ""

View File

@@ -25,7 +25,7 @@ async def test_resource_templates():
# The handler returns a ServerResult with a ListResourceTemplatesResult inside
result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest](
types.ListResourceTemplatesRequest(
method="resources/templates/list", params=None, cursor=None
method="resources/templates/list", params=None
)
)
assert isinstance(result.root, types.ListResourceTemplatesResult)