From eb1024c654d787fcca81acefd161d3daadbb6b0a Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 21 Oct 2024 14:50:44 +0100 Subject: [PATCH] Update protocol version handling --- mcp_python/client/session.py | 7 ++++--- mcp_python/server/session.py | 4 ++-- mcp_python/shared/version.py | 4 +++- tests/client/test_session.py | 5 +++-- tests/test_types.py | 6 +++--- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/mcp_python/client/session.py b/mcp_python/client/session.py index 769e945..266e741 100644 --- a/mcp_python/client/session.py +++ b/mcp_python/client/session.py @@ -2,8 +2,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre from pydantic import AnyUrl from mcp_python.shared.session import BaseSession -from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION +from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp_python.types import ( + LATEST_PROTOCOL_VERSION, CallToolResult, ClientCapabilities, ClientNotification, @@ -49,7 +50,7 @@ class ClientSession( InitializeRequest( method="initialize", params=InitializeRequestParams( - protocolVersion=SUPPORTED_PROTOCOL_VERSION, + protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=ClientCapabilities( sampling=None, experimental=None ), @@ -60,7 +61,7 @@ class ClientSession( InitializeResult, ) - if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION: + if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError( "Unsupported protocol version from the server: " f"{result.protocolVersion}" diff --git a/mcp_python/server/session.py b/mcp_python/server/session.py index 375e557..be8a4df 100644 --- a/mcp_python/server/session.py +++ b/mcp_python/server/session.py @@ -11,8 +11,8 @@ from mcp_python.shared.session import ( BaseSession, RequestResponder, ) -from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION from mcp_python.types import ( + LATEST_PROTOCOL_VERSION, ClientNotification, ClientRequest, CreateMessageResult, @@ -67,7 +67,7 @@ class ServerSession( await responder.respond( ServerResult( InitializeResult( - protocolVersion=SUPPORTED_PROTOCOL_VERSION, + protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=self._init_options.capabilities, serverInfo=Implementation( name=self._init_options.server_name, diff --git a/mcp_python/shared/version.py b/mcp_python/shared/version.py index bc8db20..61d1fbe 100644 --- a/mcp_python/shared/version.py +++ b/mcp_python/shared/version.py @@ -1 +1,3 @@ -SUPPORTED_PROTOCOL_VERSION = 1 +from mcp_python.types import LATEST_PROTOCOL_VERSION + +SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION] diff --git a/tests/client/test_session.py b/tests/client/test_session.py index cb7f038..f71a4cb 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -3,6 +3,7 @@ import pytest from mcp_python.client.session import ClientSession from mcp_python.types import ( + LATEST_PROTOCOL_VERSION, ClientNotification, ClientRequest, Implementation, @@ -41,7 +42,7 @@ async def test_client_session_initialize(): result = ServerResult( InitializeResult( - protocolVersion=1, + protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=ServerCapabilities( logging=None, resources=None, @@ -88,7 +89,7 @@ async def test_client_session_initialize(): # Assert the result assert isinstance(result, InitializeResult) - assert result.protocolVersion == 1 + assert result.protocolVersion == LATEST_PROTOCOL_VERSION assert isinstance(result.capabilities, ServerCapabilities) assert result.serverInfo == Implementation(name="mock-server", version="0.1.0") diff --git a/tests/test_types.py b/tests/test_types.py index decd1c8..d9bc46a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,4 +1,4 @@ -from mcp_python.types import ClientRequest, JSONRPCMessage, JSONRPCRequest +from mcp_python.types import LATEST_PROTOCOL_VERSION, ClientRequest, JSONRPCMessage, JSONRPCRequest def test_jsonrpc_request(): @@ -7,7 +7,7 @@ def test_jsonrpc_request(): "id": 1, "method": "initialize", "params": { - "protocolVersion": 1, + "protocolVersion": LATEST_PROTOCOL_VERSION, "capabilities": {"batch": None, "sampling": None}, "clientInfo": {"name": "mcp_python", "version": "0.1.0"}, }, @@ -21,4 +21,4 @@ def test_jsonrpc_request(): assert request.root.id == 1 assert request.root.method == "initialize" assert request.root.params is not None - assert request.root.params["protocolVersion"] == 1 + assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION