mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
feat: add client capability checking to ServerSession
Add methods to track and verify client capabilities during initialization. This includes storing client parameters from the initialize request and providing a check_client_capability method to verify if specific capabilities are supported by the connected client.
This commit is contained in:
@@ -30,6 +30,7 @@ class ServerSession(
|
||||
]
|
||||
):
|
||||
_initialized: InitializationState = InitializationState.NotInitialized
|
||||
_client_params: types.InitializeRequestParams | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -43,12 +44,47 @@ class ServerSession(
|
||||
self._initialization_state = InitializationState.NotInitialized
|
||||
self._init_options = init_options
|
||||
|
||||
@property
|
||||
def client_params(self) -> types.InitializeRequestParams | None:
|
||||
return self._client_params
|
||||
|
||||
def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
|
||||
"""Check if the client supports a specific capability."""
|
||||
if self._client_params is None:
|
||||
return False
|
||||
|
||||
# Get client capabilities from initialization params
|
||||
client_caps = self._client_params.capabilities
|
||||
|
||||
# Check each specified capability in the passed in capability object
|
||||
if capability.roots is not None:
|
||||
if client_caps.roots is None:
|
||||
return False
|
||||
if capability.roots.listChanged and not client_caps.roots.listChanged:
|
||||
return False
|
||||
|
||||
if capability.sampling is not None:
|
||||
if client_caps.sampling is None:
|
||||
return False
|
||||
|
||||
if capability.experimental is not None:
|
||||
if client_caps.experimental is None:
|
||||
return False
|
||||
# Check each experimental capability
|
||||
for exp_key, exp_value in capability.experimental.items():
|
||||
if (exp_key not in client_caps.experimental or
|
||||
client_caps.experimental[exp_key] != exp_value):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
):
|
||||
match responder.request.root:
|
||||
case types.InitializeRequest():
|
||||
case types.InitializeRequest(params=params):
|
||||
self._initialization_state = InitializationState.Initializing
|
||||
self._client_params = params
|
||||
await responder.respond(
|
||||
types.ServerResult(
|
||||
types.InitializeResult(
|
||||
@@ -81,6 +117,7 @@ class ServerSession(
|
||||
"Received notification before initialization was complete"
|
||||
)
|
||||
|
||||
|
||||
async def send_log_message(
|
||||
self, level: types.LoggingLevel, data: Any, logger: str | None = None
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user