mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Merge pull request #19 from modelcontextprotocol/justin/upgrade-spec
Upgrade to protocol version 2024-10-07
This commit is contained in:
@@ -2,8 +2,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.shared.session import BaseSession
|
||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
from mcp_python.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
CallToolResult,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
@@ -49,7 +50,7 @@ class ClientSession(
|
||||
InitializeRequest(
|
||||
method="initialize",
|
||||
params=InitializeRequestParams(
|
||||
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ClientCapabilities(
|
||||
sampling=None, experimental=None
|
||||
),
|
||||
@@ -60,7 +61,7 @@ class ClientSession(
|
||||
InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION:
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(
|
||||
"Unsupported protocol version from the server: "
|
||||
f"{result.protocolVersion}"
|
||||
|
||||
@@ -162,7 +162,7 @@ class Server:
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
return ServerResult(
|
||||
ListResourcesResult(resources=resources, resourceTemplates=None)
|
||||
ListResourcesResult(resources=resources)
|
||||
)
|
||||
|
||||
self.request_handlers[ListResourcesRequest] = handler
|
||||
|
||||
@@ -11,8 +11,8 @@ from mcp_python.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
)
|
||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
||||
from mcp_python.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CreateMessageResult,
|
||||
@@ -67,7 +67,7 @@ class ServerSession(
|
||||
await responder.respond(
|
||||
ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self._init_options.capabilities,
|
||||
serverInfo=Implementation(
|
||||
name=self._init_options.server_name,
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
SUPPORTED_PROTOCOL_VERSION = 1
|
||||
from mcp_python.types import LATEST_PROTOCOL_VERSION
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION]
|
||||
|
||||
@@ -21,8 +21,10 @@ for reference.
|
||||
not separate types in the schema.
|
||||
"""
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "2024-10-07"
|
||||
|
||||
ProgressToken = str | int
|
||||
Cursor = str
|
||||
|
||||
|
||||
class RequestParams(BaseModel):
|
||||
@@ -64,6 +66,14 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedRequest(Request[RequestParamsT, MethodT]):
|
||||
cursor: Cursor | None = None
|
||||
"""
|
||||
An opaque token representing the current pagination position.
|
||||
If provided, the server should return results starting after this cursor.
|
||||
"""
|
||||
|
||||
|
||||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
||||
"""Base class for JSON-RPC notifications."""
|
||||
|
||||
@@ -83,6 +93,14 @@ class Result(BaseModel):
|
||||
"""
|
||||
|
||||
|
||||
class PaginatedResult(Result):
|
||||
nextCursor: Cursor | None = None
|
||||
"""
|
||||
An opaque token representing the pagination position after the last returned result.
|
||||
If present, there may be more results available.
|
||||
"""
|
||||
|
||||
|
||||
RequestId = str | int
|
||||
|
||||
|
||||
@@ -115,6 +133,7 @@ PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
|
||||
|
||||
class ErrorData(BaseModel):
|
||||
@@ -191,7 +210,7 @@ class ServerCapabilities(BaseModel):
|
||||
class InitializeRequestParams(RequestParams):
|
||||
"""Parameters for the initialize request."""
|
||||
|
||||
protocolVersion: Literal[1]
|
||||
protocolVersion: str | int
|
||||
"""The latest version of the Model Context Protocol that the client supports."""
|
||||
capabilities: ClientCapabilities
|
||||
clientInfo: Implementation
|
||||
@@ -211,7 +230,7 @@ class InitializeRequest(Request):
|
||||
class InitializeResult(Result):
|
||||
"""After receiving an initialize request from the client, the server sends this."""
|
||||
|
||||
protocolVersion: Literal[1]
|
||||
protocolVersion: str | int
|
||||
"""The version of the Model Context Protocol that the server wants to use."""
|
||||
capabilities: ServerCapabilities
|
||||
serverInfo: Implementation
|
||||
@@ -265,7 +284,7 @@ class ProgressNotification(Notification):
|
||||
params: ProgressNotificationParams
|
||||
|
||||
|
||||
class ListResourcesRequest(Request):
|
||||
class ListResourcesRequest(PaginatedRequest):
|
||||
"""Sent from the client to request a list of resources the server has."""
|
||||
|
||||
method: Literal["resources/list"]
|
||||
@@ -277,6 +296,10 @@ class Resource(BaseModel):
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of this resource."""
|
||||
name: str
|
||||
"""A human-readable name for this resource."""
|
||||
description: str | None = None
|
||||
"""A description of what this resource represents."""
|
||||
mimeType: str | None = None
|
||||
"""The MIME type of this resource, if known."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
@@ -290,7 +313,7 @@ class ResourceTemplate(BaseModel):
|
||||
A URI template (according to RFC 6570) that can be used to construct resource
|
||||
URIs.
|
||||
"""
|
||||
name: str | None = None
|
||||
name: str
|
||||
"""A human-readable name for the type of resource this template refers to."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of what this template is for."""
|
||||
@@ -302,11 +325,23 @@ class ResourceTemplate(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ListResourcesResult(Result):
|
||||
class ListResourcesResult(PaginatedResult):
|
||||
"""The server's response to a resources/list request from the client."""
|
||||
|
||||
resourceTemplates: list[ResourceTemplate] | None = None
|
||||
resources: list[Resource] | None = None
|
||||
resources: list[Resource]
|
||||
|
||||
|
||||
class ListResourceTemplatesRequest(PaginatedRequest):
|
||||
"""Sent from the client to request a list of resource templates the server has."""
|
||||
|
||||
method: Literal["resources/templates/list"]
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
class ListResourceTemplatesResult(PaginatedResult):
|
||||
"""The server's response to a resources/templates/list request from the client."""
|
||||
|
||||
resourceTemplates: list[ResourceTemplate]
|
||||
|
||||
|
||||
class ReadResourceRequestParams(RequestParams):
|
||||
@@ -430,7 +465,7 @@ class ResourceUpdatedNotification(Notification):
|
||||
params: ResourceUpdatedNotificationParams
|
||||
|
||||
|
||||
class ListPromptsRequest(Request):
|
||||
class ListPromptsRequest(PaginatedRequest):
|
||||
"""Sent from the client to request a list of prompts and prompt templates."""
|
||||
|
||||
method: Literal["prompts/list"]
|
||||
@@ -461,7 +496,7 @@ class Prompt(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ListPromptsResult(Result):
|
||||
class ListPromptsResult(PaginatedResult):
|
||||
"""The server's response to a prompts/list request from the client."""
|
||||
|
||||
prompts: list[Prompt]
|
||||
@@ -526,7 +561,17 @@ class GetPromptResult(Result):
|
||||
messages: list[SamplingMessage]
|
||||
|
||||
|
||||
class ListToolsRequest(Request):
|
||||
class PromptListChangedNotification(Notification):
|
||||
"""
|
||||
An optional notification from the server to the client, informing it that the list
|
||||
of prompts it offers has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/prompts/list_changed"]
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
class ListToolsRequest(PaginatedRequest):
|
||||
"""Sent from the client to request a list of tools the server has."""
|
||||
|
||||
method: Literal["tools/list"]
|
||||
@@ -545,7 +590,7 @@ class Tool(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ListToolsResult(Result):
|
||||
class ListToolsResult(PaginatedResult):
|
||||
"""The server's response to a tools/list request from the client."""
|
||||
|
||||
tools: list[Tool]
|
||||
@@ -742,6 +787,7 @@ class ClientRequest(
|
||||
| GetPromptRequest
|
||||
| ListPromptsRequest
|
||||
| ListResourcesRequest
|
||||
| ListResourceTemplatesRequest
|
||||
| ReadResourceRequest
|
||||
| SubscribeRequest
|
||||
| UnsubscribeRequest
|
||||
@@ -771,6 +817,7 @@ class ServerNotification(
|
||||
| ResourceUpdatedNotification
|
||||
| ResourceListChangedNotification
|
||||
| ToolListChangedNotification
|
||||
| PromptListChangedNotification
|
||||
]
|
||||
):
|
||||
pass
|
||||
@@ -784,6 +831,7 @@ class ServerResult(
|
||||
| GetPromptResult
|
||||
| ListPromptsResult
|
||||
| ListResourcesResult
|
||||
| ListResourceTemplatesResult
|
||||
| ReadResourceResult
|
||||
| CallToolResult
|
||||
| ListToolsResult
|
||||
|
||||
@@ -3,6 +3,7 @@ import pytest
|
||||
|
||||
from mcp_python.client.session import ClientSession
|
||||
from mcp_python.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
Implementation,
|
||||
@@ -41,7 +42,7 @@ async def test_client_session_initialize():
|
||||
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=1,
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(
|
||||
logging=None,
|
||||
resources=None,
|
||||
@@ -88,7 +89,7 @@ async def test_client_session_initialize():
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.protocolVersion == 1
|
||||
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
|
||||
assert isinstance(result.capabilities, ServerCapabilities)
|
||||
assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")
|
||||
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
from mcp_python.types import ClientRequest, JSONRPCMessage, JSONRPCRequest
|
||||
from mcp_python.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
ClientRequest,
|
||||
JSONRPCMessage,
|
||||
JSONRPCRequest,
|
||||
)
|
||||
|
||||
|
||||
def test_jsonrpc_request():
|
||||
@@ -7,7 +12,7 @@ def test_jsonrpc_request():
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": 1,
|
||||
"protocolVersion": LATEST_PROTOCOL_VERSION,
|
||||
"capabilities": {"batch": None, "sampling": None},
|
||||
"clientInfo": {"name": "mcp_python", "version": "0.1.0"},
|
||||
},
|
||||
@@ -21,4 +26,4 @@ def test_jsonrpc_request():
|
||||
assert request.root.id == 1
|
||||
assert request.root.method == "initialize"
|
||||
assert request.root.params is not None
|
||||
assert request.root.params["protocolVersion"] == 1
|
||||
assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION
|
||||
|
||||
Reference in New Issue
Block a user