fix: Pass cursor parameter to server (#745)

This commit is contained in:
Nate Barbettini
2025-05-21 14:27:06 -07:00
committed by GitHub
parent 2ca2de767b
commit e80c0150e1
5 changed files with 306 additions and 73 deletions

View File

@@ -209,7 +209,9 @@ class ClientSession(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListResourcesResult,
@@ -223,7 +225,9 @@ class ClientSession(
types.ClientRequest(
types.ListResourceTemplatesRequest(
method="resources/templates/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListResourceTemplatesResult,
@@ -295,7 +299,9 @@ class ClientSession(
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListPromptsResult,
@@ -340,7 +346,9 @@ class ClientSession(
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListToolsResult,

View File

@@ -53,6 +53,14 @@ class RequestParams(BaseModel):
meta: Meta | None = Field(alias="_meta", default=None)
class PaginatedRequestParams(RequestParams):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class NotificationParams(BaseModel):
class Meta(BaseModel):
model_config = ConfigDict(extra="allow")
@@ -79,12 +87,13 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class PaginatedRequest(
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]
):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@@ -358,13 +367,10 @@ class ProgressNotification(
params: ProgressNotificationParams
class ListResourcesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
):
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
"""Sent from the client to request a list of resources the server has."""
method: Literal["resources/list"]
params: RequestParams | None = None
class Annotations(BaseModel):
@@ -423,12 +429,11 @@ class ListResourcesResult(PaginatedResult):
class ListResourceTemplatesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
PaginatedRequest[Literal["resources/templates/list"]]
):
"""Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"]
params: RequestParams | None = None
class ListResourceTemplatesResult(PaginatedResult):
@@ -570,13 +575,10 @@ class ResourceUpdatedNotification(
params: ResourceUpdatedNotificationParams
class ListPromptsRequest(
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
):
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
"""Sent from the client to request a list of prompts and prompt templates."""
method: Literal["prompts/list"]
params: RequestParams | None = None
class PromptArgument(BaseModel):
@@ -703,11 +705,10 @@ class PromptListChangedNotification(
params: NotificationParams | None = None
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""
method: Literal["tools/list"]
params: RequestParams | None = None
class ToolAnnotations(BaseModel):
@@ -741,7 +742,7 @@ class ToolAnnotations(BaseModel):
idempotentHint: bool | None = None
"""
If true, calling the tool repeatedly with the same arguments
If true, calling the tool repeatedly with the same arguments
will have no additional effect on the its environment.
(This property is meaningful only when `readOnlyHint == false`)
Default: false

145
tests/client/conftest.py Normal file
View File

@@ -0,0 +1,145 @@
from contextlib import asynccontextmanager
from unittest.mock import patch
import pytest
import mcp.shared.memory
from mcp.shared.message import SessionMessage
from mcp.types import (
JSONRPCNotification,
JSONRPCRequest,
)
class SpyMemoryObjectSendStream:
def __init__(self, original_stream):
self.original_stream = original_stream
self.sent_messages: list[SessionMessage] = []
async def send(self, message):
self.sent_messages.append(message)
await self.original_stream.send(message)
async def aclose(self):
await self.original_stream.aclose()
async def __aenter__(self):
return self
async def __aexit__(self, *args):
await self.aclose()
class StreamSpyCollection:
def __init__(
self,
client_spy: SpyMemoryObjectSendStream,
server_spy: SpyMemoryObjectSendStream,
):
self.client = client_spy
self.server = server_spy
def clear(self) -> None:
"""Clear all captured messages."""
self.client.sent_messages.clear()
self.server.sent_messages.clear()
def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
"""Get client-sent requests, optionally filtered by method."""
return [
req.message.root
for req in self.client.sent_messages
if isinstance(req.message.root, JSONRPCRequest)
and (method is None or req.message.root.method == method)
]
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
"""Get server-sent requests, optionally filtered by method."""
return [
req.message.root
for req in self.server.sent_messages
if isinstance(req.message.root, JSONRPCRequest)
and (method is None or req.message.root.method == method)
]
def get_client_notifications(
self, method: str | None = None
) -> list[JSONRPCNotification]:
"""Get client-sent notifications, optionally filtered by method."""
return [
notif.message.root
for notif in self.client.sent_messages
if isinstance(notif.message.root, JSONRPCNotification)
and (method is None or notif.message.root.method == method)
]
def get_server_notifications(
self, method: str | None = None
) -> list[JSONRPCNotification]:
"""Get server-sent notifications, optionally filtered by method."""
return [
notif.message.root
for notif in self.server.sent_messages
if isinstance(notif.message.root, JSONRPCNotification)
and (method is None or notif.message.root.method == method)
]
@pytest.fixture
def stream_spy():
"""Fixture that provides spies for both client and server write streams.
Example usage:
async def test_something(stream_spy):
# ... set up server and client ...
spies = stream_spy()
# Run some operation that sends messages
await client.some_operation()
# Check the messages
requests = spies.get_client_requests(method="some/method")
assert len(requests) == 1
# Clear for the next operation
spies.clear()
"""
client_spy = None
server_spy = None
# Store references to our spy objects
def capture_spies(c_spy, s_spy):
nonlocal client_spy, server_spy
client_spy = c_spy
server_spy = s_spy
# Create patched version of stream creation
original_create_streams = mcp.shared.memory.create_client_server_memory_streams
@asynccontextmanager
async def patched_create_streams():
async with original_create_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams
# Create spy wrappers
spy_client_write = SpyMemoryObjectSendStream(client_write)
spy_server_write = SpyMemoryObjectSendStream(server_write)
# Capture references for the test to use
capture_spies(spy_client_write, spy_server_write)
yield (client_read, spy_client_write), (server_read, spy_server_write)
# Apply the patch for the duration of the test
with patch(
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
):
# Return a collection with helper methods
def get_spy_collection() -> StreamSpyCollection:
assert client_spy is not None, "client_spy was not initialized"
assert server_spy is not None, "server_spy was not initialized"
return StreamSpyCollection(client_spy, server_spy)
yield get_spy_collection

View File

@@ -9,11 +9,11 @@ from mcp.shared.memory import (
pytestmark = pytest.mark.anyio
async def test_list_tools_cursor_parameter():
"""Test that the cursor parameter is accepted for list_tools.
async def test_list_tools_cursor_parameter(stream_spy):
"""Test that the cursor parameter is accepted for list_tools
and that it is correctly passed to the server.
Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
server = FastMCP("test")
@@ -29,28 +29,46 @@ async def test_list_tools_cursor_parameter():
return "Result 2"
async with create_session(server._mcp_server) as client_session:
spies = stream_spy()
# Test without cursor parameter (omitted)
result1 = await client_session.list_tools()
assert len(result1.tools) == 2
_ = await client_session.list_tools()
list_tools_requests = spies.get_client_requests(method="tools/list")
assert len(list_tools_requests) == 1
assert list_tools_requests[0].params is None
spies.clear()
# Test with cursor=None
result2 = await client_session.list_tools(cursor=None)
assert len(result2.tools) == 2
_ = await client_session.list_tools(cursor=None)
list_tools_requests = spies.get_client_requests(method="tools/list")
assert len(list_tools_requests) == 1
assert list_tools_requests[0].params is None
spies.clear()
# Test with cursor as string
result3 = await client_session.list_tools(cursor="some_cursor_value")
assert len(result3.tools) == 2
_ = await client_session.list_tools(cursor="some_cursor_value")
list_tools_requests = spies.get_client_requests(method="tools/list")
assert len(list_tools_requests) == 1
assert list_tools_requests[0].params is not None
assert list_tools_requests[0].params["cursor"] == "some_cursor_value"
spies.clear()
# Test with empty string cursor
result4 = await client_session.list_tools(cursor="")
assert len(result4.tools) == 2
_ = await client_session.list_tools(cursor="")
list_tools_requests = spies.get_client_requests(method="tools/list")
assert len(list_tools_requests) == 1
assert list_tools_requests[0].params is not None
assert list_tools_requests[0].params["cursor"] == ""
async def test_list_resources_cursor_parameter():
"""Test that the cursor parameter is accepted for list_resources.
async def test_list_resources_cursor_parameter(stream_spy):
"""Test that the cursor parameter is accepted for list_resources
and that it is correctly passed to the server.
Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
server = FastMCP("test")
@@ -61,28 +79,45 @@ async def test_list_resources_cursor_parameter():
return "Test data"
async with create_session(server._mcp_server) as client_session:
spies = stream_spy()
# Test without cursor parameter (omitted)
result1 = await client_session.list_resources()
assert len(result1.resources) >= 1
_ = await client_session.list_resources()
list_resources_requests = spies.get_client_requests(method="resources/list")
assert len(list_resources_requests) == 1
assert list_resources_requests[0].params is None
spies.clear()
# Test with cursor=None
result2 = await client_session.list_resources(cursor=None)
assert len(result2.resources) >= 1
_ = await client_session.list_resources(cursor=None)
list_resources_requests = spies.get_client_requests(method="resources/list")
assert len(list_resources_requests) == 1
assert list_resources_requests[0].params is None
spies.clear()
# Test with cursor as string
result3 = await client_session.list_resources(cursor="some_cursor")
assert len(result3.resources) >= 1
_ = await client_session.list_resources(cursor="some_cursor")
list_resources_requests = spies.get_client_requests(method="resources/list")
assert len(list_resources_requests) == 1
assert list_resources_requests[0].params is not None
assert list_resources_requests[0].params["cursor"] == "some_cursor"
spies.clear()
# Test with empty string cursor
result4 = await client_session.list_resources(cursor="")
assert len(result4.resources) >= 1
_ = await client_session.list_resources(cursor="")
list_resources_requests = spies.get_client_requests(method="resources/list")
assert len(list_resources_requests) == 1
assert list_resources_requests[0].params is not None
assert list_resources_requests[0].params["cursor"] == ""
async def test_list_prompts_cursor_parameter():
"""Test that the cursor parameter is accepted for list_prompts.
Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
async def test_list_prompts_cursor_parameter(stream_spy):
"""Test that the cursor parameter is accepted for list_prompts
and that it is correctly passed to the server.
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
server = FastMCP("test")
@@ -93,28 +128,46 @@ async def test_list_prompts_cursor_parameter():
return f"Hello, {name}!"
async with create_session(server._mcp_server) as client_session:
spies = stream_spy()
# Test without cursor parameter (omitted)
result1 = await client_session.list_prompts()
assert len(result1.prompts) >= 1
_ = await client_session.list_prompts()
list_prompts_requests = spies.get_client_requests(method="prompts/list")
assert len(list_prompts_requests) == 1
assert list_prompts_requests[0].params is None
spies.clear()
# Test with cursor=None
result2 = await client_session.list_prompts(cursor=None)
assert len(result2.prompts) >= 1
_ = await client_session.list_prompts(cursor=None)
list_prompts_requests = spies.get_client_requests(method="prompts/list")
assert len(list_prompts_requests) == 1
assert list_prompts_requests[0].params is None
spies.clear()
# Test with cursor as string
result3 = await client_session.list_prompts(cursor="some_cursor")
assert len(result3.prompts) >= 1
_ = await client_session.list_prompts(cursor="some_cursor")
list_prompts_requests = spies.get_client_requests(method="prompts/list")
assert len(list_prompts_requests) == 1
assert list_prompts_requests[0].params is not None
assert list_prompts_requests[0].params["cursor"] == "some_cursor"
spies.clear()
# Test with empty string cursor
result4 = await client_session.list_prompts(cursor="")
assert len(result4.prompts) >= 1
_ = await client_session.list_prompts(cursor="")
list_prompts_requests = spies.get_client_requests(method="prompts/list")
assert len(list_prompts_requests) == 1
assert list_prompts_requests[0].params is not None
assert list_prompts_requests[0].params["cursor"] == ""
async def test_list_resource_templates_cursor_parameter():
"""Test that the cursor parameter is accepted for list_resource_templates.
async def test_list_resource_templates_cursor_parameter(stream_spy):
"""Test that the cursor parameter is accepted for list_resource_templates
and that it is correctly passed to the server.
Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
server = FastMCP("test")
@@ -125,18 +178,44 @@ async def test_list_resource_templates_cursor_parameter():
return f"Data for {name}"
async with create_session(server._mcp_server) as client_session:
spies = stream_spy()
# Test without cursor parameter (omitted)
result1 = await client_session.list_resource_templates()
assert len(result1.resourceTemplates) >= 1
_ = await client_session.list_resource_templates()
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None
spies.clear()
# Test with cursor=None
result2 = await client_session.list_resource_templates(cursor=None)
assert len(result2.resourceTemplates) >= 1
_ = await client_session.list_resource_templates(cursor=None)
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None
spies.clear()
# Test with cursor as string
result3 = await client_session.list_resource_templates(cursor="some_cursor")
assert len(result3.resourceTemplates) >= 1
_ = await client_session.list_resource_templates(cursor="some_cursor")
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == "some_cursor"
spies.clear()
# Test with empty string cursor
result4 = await client_session.list_resource_templates(cursor="")
assert len(result4.resourceTemplates) >= 1
_ = await client_session.list_resource_templates(cursor="")
list_templates_requests = spies.get_client_requests(
method="resources/templates/list"
)
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == ""

View File

@@ -25,7 +25,7 @@ async def test_resource_templates():
# The handler returns a ServerResult with a ListResourceTemplatesResult inside
result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest](
types.ListResourceTemplatesRequest(
method="resources/templates/list", params=None, cursor=None
method="resources/templates/list", params=None
)
)
assert isinstance(result.root, types.ListResourceTemplatesResult)