Update protocol version handling

This commit is contained in:
Justin Spahr-Summers
2024-10-21 14:50:44 +01:00
parent 2d55eabb2f
commit eb1024c654
5 changed files with 15 additions and 11 deletions

View File

@@ -2,8 +2,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from pydantic import AnyUrl from pydantic import AnyUrl
from mcp_python.shared.session import BaseSession from mcp_python.shared.session import BaseSession
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp_python.types import ( from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
CallToolResult, CallToolResult,
ClientCapabilities, ClientCapabilities,
ClientNotification, ClientNotification,
@@ -49,7 +50,7 @@ class ClientSession(
InitializeRequest( InitializeRequest(
method="initialize", method="initialize",
params=InitializeRequestParams( params=InitializeRequestParams(
protocolVersion=SUPPORTED_PROTOCOL_VERSION, protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities( capabilities=ClientCapabilities(
sampling=None, experimental=None sampling=None, experimental=None
), ),
@@ -60,7 +61,7 @@ class ClientSession(
InitializeResult, InitializeResult,
) )
if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError( raise RuntimeError(
"Unsupported protocol version from the server: " "Unsupported protocol version from the server: "
f"{result.protocolVersion}" f"{result.protocolVersion}"

View File

@@ -11,8 +11,8 @@ from mcp_python.shared.session import (
BaseSession, BaseSession,
RequestResponder, RequestResponder,
) )
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
from mcp_python.types import ( from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
CreateMessageResult, CreateMessageResult,
@@ -67,7 +67,7 @@ class ServerSession(
await responder.respond( await responder.respond(
ServerResult( ServerResult(
InitializeResult( InitializeResult(
protocolVersion=SUPPORTED_PROTOCOL_VERSION, protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities, capabilities=self._init_options.capabilities,
serverInfo=Implementation( serverInfo=Implementation(
name=self._init_options.server_name, name=self._init_options.server_name,

View File

@@ -1 +1,3 @@
SUPPORTED_PROTOCOL_VERSION = 1 from mcp_python.types import LATEST_PROTOCOL_VERSION
SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION]

View File

@@ -3,6 +3,7 @@ import pytest
from mcp_python.client.session import ClientSession from mcp_python.client.session import ClientSession
from mcp_python.types import ( from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
Implementation, Implementation,
@@ -41,7 +42,7 @@ async def test_client_session_initialize():
result = ServerResult( result = ServerResult(
InitializeResult( InitializeResult(
protocolVersion=1, protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities( capabilities=ServerCapabilities(
logging=None, logging=None,
resources=None, resources=None,
@@ -88,7 +89,7 @@ async def test_client_session_initialize():
# Assert the result # Assert the result
assert isinstance(result, InitializeResult) assert isinstance(result, InitializeResult)
assert result.protocolVersion == 1 assert result.protocolVersion == LATEST_PROTOCOL_VERSION
assert isinstance(result.capabilities, ServerCapabilities) assert isinstance(result.capabilities, ServerCapabilities)
assert result.serverInfo == Implementation(name="mock-server", version="0.1.0") assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")

View File

@@ -1,4 +1,4 @@
from mcp_python.types import ClientRequest, JSONRPCMessage, JSONRPCRequest from mcp_python.types import LATEST_PROTOCOL_VERSION, ClientRequest, JSONRPCMessage, JSONRPCRequest
def test_jsonrpc_request(): def test_jsonrpc_request():
@@ -7,7 +7,7 @@ def test_jsonrpc_request():
"id": 1, "id": 1,
"method": "initialize", "method": "initialize",
"params": { "params": {
"protocolVersion": 1, "protocolVersion": LATEST_PROTOCOL_VERSION,
"capabilities": {"batch": None, "sampling": None}, "capabilities": {"batch": None, "sampling": None},
"clientInfo": {"name": "mcp_python", "version": "0.1.0"}, "clientInfo": {"name": "mcp_python", "version": "0.1.0"},
}, },
@@ -21,4 +21,4 @@ def test_jsonrpc_request():
assert request.root.id == 1 assert request.root.id == 1
assert request.root.method == "initialize" assert request.root.method == "initialize"
assert request.root.params is not None assert request.root.params is not None
assert request.root.params["protocolVersion"] == 1 assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION