Merge pull request #19 from modelcontextprotocol/justin/upgrade-spec

Upgrade to protocol version 2024-10-07
This commit is contained in:
Justin Spahr-Summers
2024-10-21 15:33:30 +01:00
committed by GitHub
7 changed files with 80 additions and 23 deletions

View File

@@ -2,8 +2,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from pydantic import AnyUrl from pydantic import AnyUrl
from mcp_python.shared.session import BaseSession from mcp_python.shared.session import BaseSession
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp_python.types import ( from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
CallToolResult, CallToolResult,
ClientCapabilities, ClientCapabilities,
ClientNotification, ClientNotification,
@@ -49,7 +50,7 @@ class ClientSession(
InitializeRequest( InitializeRequest(
method="initialize", method="initialize",
params=InitializeRequestParams( params=InitializeRequestParams(
protocolVersion=SUPPORTED_PROTOCOL_VERSION, protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities( capabilities=ClientCapabilities(
sampling=None, experimental=None sampling=None, experimental=None
), ),
@@ -60,7 +61,7 @@ class ClientSession(
InitializeResult, InitializeResult,
) )
if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError( raise RuntimeError(
"Unsupported protocol version from the server: " "Unsupported protocol version from the server: "
f"{result.protocolVersion}" f"{result.protocolVersion}"

View File

@@ -162,7 +162,7 @@ class Server:
async def handler(_: Any): async def handler(_: Any):
resources = await func() resources = await func()
return ServerResult( return ServerResult(
ListResourcesResult(resources=resources, resourceTemplates=None) ListResourcesResult(resources=resources)
) )
self.request_handlers[ListResourcesRequest] = handler self.request_handlers[ListResourcesRequest] = handler

View File

@@ -11,8 +11,8 @@ from mcp_python.shared.session import (
BaseSession, BaseSession,
RequestResponder, RequestResponder,
) )
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
from mcp_python.types import ( from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
CreateMessageResult, CreateMessageResult,
@@ -67,7 +67,7 @@ class ServerSession(
await responder.respond( await responder.respond(
ServerResult( ServerResult(
InitializeResult( InitializeResult(
protocolVersion=SUPPORTED_PROTOCOL_VERSION, protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities, capabilities=self._init_options.capabilities,
serverInfo=Implementation( serverInfo=Implementation(
name=self._init_options.server_name, name=self._init_options.server_name,

View File

@@ -1 +1,3 @@
SUPPORTED_PROTOCOL_VERSION = 1 from mcp_python.types import LATEST_PROTOCOL_VERSION
SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION]

View File

@@ -21,8 +21,10 @@ for reference.
not separate types in the schema. not separate types in the schema.
""" """
LATEST_PROTOCOL_VERSION = "2024-10-07"
ProgressToken = str | int ProgressToken = str | int
Cursor = str
class RequestParams(BaseModel): class RequestParams(BaseModel):
@@ -64,6 +66,14 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
"""Base class for JSON-RPC notifications.""" """Base class for JSON-RPC notifications."""
@@ -83,6 +93,14 @@ class Result(BaseModel):
""" """
class PaginatedResult(Result):
nextCursor: Cursor | None = None
"""
An opaque token representing the pagination position after the last returned result.
If present, there may be more results available.
"""
RequestId = str | int RequestId = str | int
@@ -115,6 +133,7 @@ PARSE_ERROR = -32700
INVALID_REQUEST = -32600 INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601 METHOD_NOT_FOUND = -32601
INVALID_PARAMS = -32602 INVALID_PARAMS = -32602
INTERNAL_ERROR = -32603
class ErrorData(BaseModel): class ErrorData(BaseModel):
@@ -191,7 +210,7 @@ class ServerCapabilities(BaseModel):
class InitializeRequestParams(RequestParams): class InitializeRequestParams(RequestParams):
"""Parameters for the initialize request.""" """Parameters for the initialize request."""
protocolVersion: Literal[1] protocolVersion: str | int
"""The latest version of the Model Context Protocol that the client supports.""" """The latest version of the Model Context Protocol that the client supports."""
capabilities: ClientCapabilities capabilities: ClientCapabilities
clientInfo: Implementation clientInfo: Implementation
@@ -211,7 +230,7 @@ class InitializeRequest(Request):
class InitializeResult(Result): class InitializeResult(Result):
"""After receiving an initialize request from the client, the server sends this.""" """After receiving an initialize request from the client, the server sends this."""
protocolVersion: Literal[1] protocolVersion: str | int
"""The version of the Model Context Protocol that the server wants to use.""" """The version of the Model Context Protocol that the server wants to use."""
capabilities: ServerCapabilities capabilities: ServerCapabilities
serverInfo: Implementation serverInfo: Implementation
@@ -265,7 +284,7 @@ class ProgressNotification(Notification):
params: ProgressNotificationParams params: ProgressNotificationParams
class ListResourcesRequest(Request): class ListResourcesRequest(PaginatedRequest):
"""Sent from the client to request a list of resources the server has.""" """Sent from the client to request a list of resources the server has."""
method: Literal["resources/list"] method: Literal["resources/list"]
@@ -277,6 +296,10 @@ class Resource(BaseModel):
uri: AnyUrl uri: AnyUrl
"""The URI of this resource.""" """The URI of this resource."""
name: str
"""A human-readable name for this resource."""
description: str | None = None
"""A description of what this resource represents."""
mimeType: str | None = None mimeType: str | None = None
"""The MIME type of this resource, if known.""" """The MIME type of this resource, if known."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@@ -290,7 +313,7 @@ class ResourceTemplate(BaseModel):
A URI template (according to RFC 6570) that can be used to construct resource A URI template (according to RFC 6570) that can be used to construct resource
URIs. URIs.
""" """
name: str | None = None name: str
"""A human-readable name for the type of resource this template refers to.""" """A human-readable name for the type of resource this template refers to."""
description: str | None = None description: str | None = None
"""A human-readable description of what this template is for.""" """A human-readable description of what this template is for."""
@@ -302,11 +325,23 @@ class ResourceTemplate(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class ListResourcesResult(Result): class ListResourcesResult(PaginatedResult):
"""The server's response to a resources/list request from the client.""" """The server's response to a resources/list request from the client."""
resourceTemplates: list[ResourceTemplate] | None = None resources: list[Resource]
resources: list[Resource] | None = None
class ListResourceTemplatesRequest(PaginatedRequest):
"""Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"]
params: RequestParams | None = None
class ListResourceTemplatesResult(PaginatedResult):
"""The server's response to a resources/templates/list request from the client."""
resourceTemplates: list[ResourceTemplate]
class ReadResourceRequestParams(RequestParams): class ReadResourceRequestParams(RequestParams):
@@ -430,7 +465,7 @@ class ResourceUpdatedNotification(Notification):
params: ResourceUpdatedNotificationParams params: ResourceUpdatedNotificationParams
class ListPromptsRequest(Request): class ListPromptsRequest(PaginatedRequest):
"""Sent from the client to request a list of prompts and prompt templates.""" """Sent from the client to request a list of prompts and prompt templates."""
method: Literal["prompts/list"] method: Literal["prompts/list"]
@@ -461,7 +496,7 @@ class Prompt(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class ListPromptsResult(Result): class ListPromptsResult(PaginatedResult):
"""The server's response to a prompts/list request from the client.""" """The server's response to a prompts/list request from the client."""
prompts: list[Prompt] prompts: list[Prompt]
@@ -526,7 +561,17 @@ class GetPromptResult(Result):
messages: list[SamplingMessage] messages: list[SamplingMessage]
class ListToolsRequest(Request): class PromptListChangedNotification(Notification):
"""
An optional notification from the server to the client, informing it that the list
of prompts it offers has changed.
"""
method: Literal["notifications/prompts/list_changed"]
params: NotificationParams | None = None
class ListToolsRequest(PaginatedRequest):
"""Sent from the client to request a list of tools the server has.""" """Sent from the client to request a list of tools the server has."""
method: Literal["tools/list"] method: Literal["tools/list"]
@@ -545,7 +590,7 @@ class Tool(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class ListToolsResult(Result): class ListToolsResult(PaginatedResult):
"""The server's response to a tools/list request from the client.""" """The server's response to a tools/list request from the client."""
tools: list[Tool] tools: list[Tool]
@@ -742,6 +787,7 @@ class ClientRequest(
| GetPromptRequest | GetPromptRequest
| ListPromptsRequest | ListPromptsRequest
| ListResourcesRequest | ListResourcesRequest
| ListResourceTemplatesRequest
| ReadResourceRequest | ReadResourceRequest
| SubscribeRequest | SubscribeRequest
| UnsubscribeRequest | UnsubscribeRequest
@@ -771,6 +817,7 @@ class ServerNotification(
| ResourceUpdatedNotification | ResourceUpdatedNotification
| ResourceListChangedNotification | ResourceListChangedNotification
| ToolListChangedNotification | ToolListChangedNotification
| PromptListChangedNotification
] ]
): ):
pass pass
@@ -784,6 +831,7 @@ class ServerResult(
| GetPromptResult | GetPromptResult
| ListPromptsResult | ListPromptsResult
| ListResourcesResult | ListResourcesResult
| ListResourceTemplatesResult
| ReadResourceResult | ReadResourceResult
| CallToolResult | CallToolResult
| ListToolsResult | ListToolsResult

View File

@@ -3,6 +3,7 @@ import pytest
from mcp_python.client.session import ClientSession from mcp_python.client.session import ClientSession
from mcp_python.types import ( from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
Implementation, Implementation,
@@ -41,7 +42,7 @@ async def test_client_session_initialize():
result = ServerResult( result = ServerResult(
InitializeResult( InitializeResult(
protocolVersion=1, protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities( capabilities=ServerCapabilities(
logging=None, logging=None,
resources=None, resources=None,
@@ -88,7 +89,7 @@ async def test_client_session_initialize():
# Assert the result # Assert the result
assert isinstance(result, InitializeResult) assert isinstance(result, InitializeResult)
assert result.protocolVersion == 1 assert result.protocolVersion == LATEST_PROTOCOL_VERSION
assert isinstance(result.capabilities, ServerCapabilities) assert isinstance(result.capabilities, ServerCapabilities)
assert result.serverInfo == Implementation(name="mock-server", version="0.1.0") assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")

View File

@@ -1,4 +1,9 @@
from mcp_python.types import ClientRequest, JSONRPCMessage, JSONRPCRequest from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ClientRequest,
JSONRPCMessage,
JSONRPCRequest,
)
def test_jsonrpc_request(): def test_jsonrpc_request():
@@ -7,7 +12,7 @@ def test_jsonrpc_request():
"id": 1, "id": 1,
"method": "initialize", "method": "initialize",
"params": { "params": {
"protocolVersion": 1, "protocolVersion": LATEST_PROTOCOL_VERSION,
"capabilities": {"batch": None, "sampling": None}, "capabilities": {"batch": None, "sampling": None},
"clientInfo": {"name": "mcp_python", "version": "0.1.0"}, "clientInfo": {"name": "mcp_python", "version": "0.1.0"},
}, },
@@ -21,4 +26,4 @@ def test_jsonrpc_request():
assert request.root.id == 1 assert request.root.id == 1
assert request.root.method == "initialize" assert request.root.method == "initialize"
assert request.root.params is not None assert request.root.params is not None
assert request.root.params["protocolVersion"] == 1 assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION