feat: add client capability checking to ServerSession

Add methods to track and verify client capabilities during initialization. This
includes storing client parameters from the initialize request and providing a
check_client_capability method to verify if specific capabilities are supported
by the connected client.
This commit is contained in:
David Soria Parra
2024-11-11 21:53:34 +00:00
parent df33a9b71c
commit 76a0b80c4c

View File

@@ -30,6 +30,7 @@ class ServerSession(
]
):
_initialized: InitializationState = InitializationState.NotInitialized
_client_params: types.InitializeRequestParams | None = None
def __init__(
self,
@@ -43,12 +44,47 @@ class ServerSession(
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
@property
def client_params(self) -> types.InitializeRequestParams | None:
return self._client_params
def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
"""Check if the client supports a specific capability."""
if self._client_params is None:
return False
# Get client capabilities from initialization params
client_caps = self._client_params.capabilities
# Check each specified capability in the passed in capability object
if capability.roots is not None:
if client_caps.roots is None:
return False
if capability.roots.listChanged and not client_caps.roots.listChanged:
return False
if capability.sampling is not None:
if client_caps.sampling is None:
return False
if capability.experimental is not None:
if client_caps.experimental is None:
return False
# Check each experimental capability
for exp_key, exp_value in capability.experimental.items():
if (exp_key not in client_caps.experimental or
client_caps.experimental[exp_key] != exp_value):
return False
return True
async def _received_request(
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
):
match responder.request.root:
case types.InitializeRequest():
case types.InitializeRequest(params=params):
self._initialization_state = InitializationState.Initializing
self._client_params = params
await responder.respond(
types.ServerResult(
types.InitializeResult(
@@ -81,6 +117,7 @@ class ServerSession(
"Received notification before initialization was complete"
)
async def send_log_message(
self, level: types.LoggingLevel, data: Any, logger: str | None = None
) -> None: