feat: add structured capability types

Replace generic capability dictionaries with structured types for prompts,
resources, tools, and roots. This improves type safety and makes capability
features like listChanged and subscribe more explicit in the protocol.
This commit is contained in:
David Soria Parra
2024-11-06 22:50:37 +00:00
parent 14addfb872
commit 5497da0afd
6 changed files with 150 additions and 29 deletions

View File

@@ -33,10 +33,13 @@ from .types import (
Notification, Notification,
PingRequest, PingRequest,
ProgressNotification, ProgressNotification,
PromptsCapability,
ReadResourceRequest, ReadResourceRequest,
ReadResourceResult, ReadResourceResult,
Resource, Resource,
ResourcesCapability,
ResourceUpdatedNotification, ResourceUpdatedNotification,
RootsCapability,
SamplingMessage, SamplingMessage,
ServerCapabilities, ServerCapabilities,
ServerNotification, ServerNotification,
@@ -46,6 +49,7 @@ from .types import (
StopReason, StopReason,
SubscribeRequest, SubscribeRequest,
Tool, Tool,
ToolsCapability,
UnsubscribeRequest, UnsubscribeRequest,
) )
from .types import ( from .types import (
@@ -82,10 +86,13 @@ __all__ = [
"Notification", "Notification",
"PingRequest", "PingRequest",
"ProgressNotification", "ProgressNotification",
"PromptsCapability",
"ReadResourceRequest", "ReadResourceRequest",
"ReadResourceResult", "ReadResourceResult",
"ResourcesCapability",
"ResourceUpdatedNotification", "ResourceUpdatedNotification",
"Resource", "Resource",
"RootsCapability",
"SamplingMessage", "SamplingMessage",
"SamplingRole", "SamplingRole",
"ServerCapabilities", "ServerCapabilities",
@@ -98,6 +105,7 @@ __all__ = [
"StopReason", "StopReason",
"SubscribeRequest", "SubscribeRequest",
"Tool", "Tool",
"ToolsCapability",
"UnsubscribeRequest", "UnsubscribeRequest",
"stdio_client", "stdio_client",
"stdio_server", "stdio_server",

View File

@@ -26,6 +26,7 @@ from mcp_python.types import (
PromptReference, PromptReference,
ReadResourceResult, ReadResourceResult,
ResourceReference, ResourceReference,
RootsCapability,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
) )
@@ -69,12 +70,12 @@ class ClientSession(
capabilities=ClientCapabilities( capabilities=ClientCapabilities(
sampling=None, sampling=None,
experimental=None, experimental=None,
roots={ roots=RootsCapability(
# TODO: Should this be based on whether we # TODO: Should this be based on whether we
# _will_ send notifications, or only whether # _will_ send notifications, or only whether
# they're supported? # they're supported?
"listChanged": True listChanged=True
}, ),
), ),
clientInfo=Implementation(name="mcp_python", version="0.1.0"), clientInfo=Implementation(name="mcp_python", version="0.1.0"),
), ),

View File

@@ -12,9 +12,19 @@ from mcp_python.types import JSONRPCMessage
# Environment variables to inherit by default # Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = ( DEFAULT_INHERITED_ENV_VARS = (
["APPDATA", "HOMEDRIVE", "HOMEPATH", "LOCALAPPDATA", "PATH", [
"PROCESSOR_ARCHITECTURE", "SYSTEMDRIVE", "SYSTEMROOT", "TEMP", "APPDATA",
"USERNAME", "USERPROFILE"] "HOMEDRIVE",
"HOMEPATH",
"LOCALAPPDATA",
"PATH",
"PROCESSOR_ARCHITECTURE",
"SYSTEMDRIVE",
"SYSTEMROOT",
"TEMP",
"USERNAME",
"USERPROFILE",
]
if sys.platform == "win32" if sys.platform == "win32"
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"] else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
) )
@@ -74,7 +84,7 @@ async def stdio_client(server: StdioServerParameters):
process = await anyio.open_process( process = await anyio.open_process(
[server.command, *server.args], [server.command, *server.args],
env=server.env if server.env is not None else get_default_environment(), env=server.env if server.env is not None else get_default_environment(),
stderr=sys.stderr stderr=sys.stderr,
) )
async def stdout_reader(): async def stdout_reader():

View File

@@ -28,22 +28,26 @@ from mcp_python.types import (
ListResourcesResult, ListResourcesResult,
ListToolsRequest, ListToolsRequest,
ListToolsResult, ListToolsResult,
LoggingCapability,
LoggingLevel, LoggingLevel,
PingRequest, PingRequest,
ProgressNotification, ProgressNotification,
Prompt, Prompt,
PromptMessage, PromptMessage,
PromptReference, PromptReference,
PromptsCapability,
ReadResourceRequest, ReadResourceRequest,
ReadResourceResult, ReadResourceResult,
Resource, Resource,
ResourceReference, ResourceReference,
ResourcesCapability,
ServerCapabilities, ServerCapabilities,
ServerResult, ServerResult,
SetLevelRequest, SetLevelRequest,
SubscribeRequest, SubscribeRequest,
TextContent, TextContent,
Tool, Tool,
ToolsCapability,
UnsubscribeRequest, UnsubscribeRequest,
) )
@@ -54,6 +58,18 @@ request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
) )
class NotificationOptions:
def __init__(
self,
prompts_changed: bool = False,
resources_changed: bool = False,
tools_changed: bool = False,
):
self.prompts_changed = prompts_changed
self.resources_changed = resources_changed
self.tools_changed = tools_changed
class Server: class Server:
def __init__(self, name: str): def __init__(self, name: str):
self.name = name self.name = name
@@ -61,9 +77,14 @@ class Server:
PingRequest: _ping_handler, PingRequest: _ping_handler,
} }
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
self.notification_options = NotificationOptions()
logger.debug(f"Initializing server '{name}'") logger.debug(f"Initializing server '{name}'")
def create_initialization_options(self) -> types.InitializationOptions: def create_initialization_options(
self,
notification_options: NotificationOptions | None = None,
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
) -> types.InitializationOptions:
"""Create initialization options from this server instance.""" """Create initialization options from this server instance."""
def pkg_version(package: str) -> str: def pkg_version(package: str) -> str:
@@ -81,20 +102,51 @@ class Server:
return types.InitializationOptions( return types.InitializationOptions(
server_name=self.name, server_name=self.name,
server_version=pkg_version("mcp_python"), server_version=pkg_version("mcp_python"),
capabilities=self.get_capabilities(), capabilities=self.get_capabilities(
notification_options or NotificationOptions(),
experimental_capabilities or {},
),
) )
def get_capabilities(self) -> ServerCapabilities: def get_capabilities(
self,
notification_options: NotificationOptions,
experimental_capabilities: dict[str, dict[str, Any]],
) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object.""" """Convert existing handlers to a ServerCapabilities object."""
prompts_capability = None
resources_capability = None
tools_capability = None
logging_capability = None
def get_capability(req_type: type) -> dict[str, Any] | None: # Set prompt capabilities if handler exists
return {} if req_type in self.request_handlers else None if ListPromptsRequest in self.request_handlers:
prompts_capability = PromptsCapability(
listChanged=notification_options.prompts_changed
)
# Set resource capabilities if handler exists
if ListResourcesRequest in self.request_handlers:
resources_capability = ResourcesCapability(
subscribe=False, listChanged=notification_options.resources_changed
)
# Set tool capabilities if handler exists
if ListToolsRequest in self.request_handlers:
tools_capability = ToolsCapability(
listChanged=notification_options.tools_changed
)
# Set logging capabilities if handler exists
if SetLevelRequest in self.request_handlers:
logging_capability = LoggingCapability()
return ServerCapabilities( return ServerCapabilities(
prompts=get_capability(ListPromptsRequest), prompts=prompts_capability,
resources=get_capability(ListResourcesRequest), resources=resources_capability,
tools=get_capability(ListToolsRequest), tools=tools_capability,
logging=get_capability(SetLevelRequest), logging=logging_capability,
experimental=experimental_capabilities,
) )
@property @property

View File

@@ -184,30 +184,76 @@ class Implementation(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class RootsCapability(BaseModel):
"""Capability for root operations."""
listChanged: bool | None = None
"""Whether the client supports notifications for changes to the roots list."""
model_config = ConfigDict(extra="allow")
class SamplingCapability(BaseModel):
"""Capability for logging operations."""
model_config = ConfigDict(extra="allow")
class ClientCapabilities(BaseModel): class ClientCapabilities(BaseModel):
"""Capabilities a client may support.""" """Capabilities a client may support."""
experimental: dict[str, dict[str, Any]] | None = None experimental: dict[str, dict[str, Any]] | None = None
"""Experimental, non-standard capabilities that the client supports.""" """Experimental, non-standard capabilities that the client supports."""
sampling: dict[str, Any] | None = None sampling: SamplingCapability | None = None
"""Present if the client supports sampling from an LLM.""" """Present if the client supports sampling from an LLM."""
roots: dict[str, Any] | None = None roots: RootsCapability | None = None
"""Present if the client supports listing roots.""" """Present if the client supports listing roots."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class PromptsCapability(BaseModel):
"""Capability for prompts operations."""
listChanged: bool | None = None
"""Whether this server supports notifications for changes to the prompt list."""
model_config = ConfigDict(extra="allow")
class ResourcesCapability(BaseModel):
"""Capability for resources operations."""
subscribe: bool | None = None
"""Whether this server supports subscribing to resource updates."""
listChanged: bool | None = None
"""Whether this server supports notifications for changes to the resource list."""
model_config = ConfigDict(extra="allow")
class ToolsCapability(BaseModel):
"""Capability for tools operations."""
listChanged: bool | None = None
"""Whether this server supports notifications for changes to the tool list."""
model_config = ConfigDict(extra="allow")
class LoggingCapability(BaseModel):
"""Capability for logging operations."""
model_config = ConfigDict(extra="allow")
class ServerCapabilities(BaseModel): class ServerCapabilities(BaseModel):
"""Capabilities that a server may support.""" """Capabilities that a server may support."""
experimental: dict[str, dict[str, Any]] | None = None experimental: dict[str, dict[str, Any]] | None = None
"""Experimental, non-standard capabilities that the server supports.""" """Experimental, non-standard capabilities that the server supports."""
logging: dict[str, Any] | None = None logging: LoggingCapability | None = None
"""Present if the server supports sending log messages to the client.""" """Present if the server supports sending log messages to the client."""
prompts: dict[str, Any] | None = None prompts: PromptsCapability | None = None
"""Present if the server offers any prompt templates.""" """Present if the server offers any prompt templates."""
resources: dict[str, Any] | None = None resources: ResourcesCapability | None = None
"""Present if the server offers any resources to read.""" """Present if the server offers any resources to read."""
tools: dict[str, Any] | None = None tools: ToolsCapability | None = None
"""Present if the server offers any tools to call.""" """Present if the server offers any tools to call."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")

View File

@@ -2,13 +2,15 @@ import anyio
import pytest import pytest
from mcp_python.client.session import ClientSession from mcp_python.client.session import ClientSession
from mcp_python.server import Server from mcp_python.server import NotificationOptions, Server
from mcp_python.server.session import ServerSession from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions from mcp_python.server.types import InitializationOptions
from mcp_python.types import ( from mcp_python.types import (
ClientNotification, ClientNotification,
InitializedNotification, InitializedNotification,
JSONRPCMessage, JSONRPCMessage,
PromptsCapability,
ResourcesCapability,
ServerCapabilities, ServerCapabilities,
) )
@@ -71,9 +73,11 @@ async def test_server_session_initialize():
@pytest.mark.anyio @pytest.mark.anyio
async def test_server_capabilities(): async def test_server_capabilities():
server = Server("test") server = Server("test")
notification_options = NotificationOptions()
experimental_capabilities = {}
# Initially no capabilities # Initially no capabilities
caps = server.get_capabilities() caps = server.get_capabilities(notification_options, experimental_capabilities)
assert caps.prompts is None assert caps.prompts is None
assert caps.resources is None assert caps.resources is None
@@ -82,8 +86,8 @@ async def test_server_capabilities():
async def list_prompts(): async def list_prompts():
return [] return []
caps = server.get_capabilities() caps = server.get_capabilities(notification_options, experimental_capabilities)
assert caps.prompts == {} assert caps.prompts == PromptsCapability(listChanged=False)
assert caps.resources is None assert caps.resources is None
# Add a resources handler # Add a resources handler
@@ -91,6 +95,6 @@ async def test_server_capabilities():
async def list_resources(): async def list_resources():
return [] return []
caps = server.get_capabilities() caps = server.get_capabilities(notification_options, experimental_capabilities)
assert caps.prompts == {} assert caps.prompts == PromptsCapability(listChanged=False)
assert caps.resources == {} assert caps.resources == ResourcesCapability(subscribe=False, listChanged=False)