Merge pull request #30 from modelcontextprotocol/davidsp/capabilities

feat: add structured capability types
This commit is contained in:
David Soria Parra
2024-11-07 14:51:20 +00:00
committed by GitHub
6 changed files with 150 additions and 29 deletions

View File

@@ -33,10 +33,13 @@ from .types import (
Notification, Notification,
PingRequest, PingRequest,
ProgressNotification, ProgressNotification,
PromptsCapability,
ReadResourceRequest, ReadResourceRequest,
ReadResourceResult, ReadResourceResult,
Resource, Resource,
ResourcesCapability,
ResourceUpdatedNotification, ResourceUpdatedNotification,
RootsCapability,
SamplingMessage, SamplingMessage,
ServerCapabilities, ServerCapabilities,
ServerNotification, ServerNotification,
@@ -46,6 +49,7 @@ from .types import (
StopReason, StopReason,
SubscribeRequest, SubscribeRequest,
Tool, Tool,
ToolsCapability,
UnsubscribeRequest, UnsubscribeRequest,
) )
from .types import ( from .types import (
@@ -82,10 +86,13 @@ __all__ = [
"Notification", "Notification",
"PingRequest", "PingRequest",
"ProgressNotification", "ProgressNotification",
"PromptsCapability",
"ReadResourceRequest", "ReadResourceRequest",
"ReadResourceResult", "ReadResourceResult",
"ResourcesCapability",
"ResourceUpdatedNotification", "ResourceUpdatedNotification",
"Resource", "Resource",
"RootsCapability",
"SamplingMessage", "SamplingMessage",
"SamplingRole", "SamplingRole",
"ServerCapabilities", "ServerCapabilities",
@@ -98,6 +105,7 @@ __all__ = [
"StopReason", "StopReason",
"SubscribeRequest", "SubscribeRequest",
"Tool", "Tool",
"ToolsCapability",
"UnsubscribeRequest", "UnsubscribeRequest",
"stdio_client", "stdio_client",
"stdio_server", "stdio_server",

View File

@@ -26,6 +26,7 @@ from mcp_python.types import (
PromptReference, PromptReference,
ReadResourceResult, ReadResourceResult,
ResourceReference, ResourceReference,
RootsCapability,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
) )
@@ -69,12 +70,12 @@ class ClientSession(
capabilities=ClientCapabilities( capabilities=ClientCapabilities(
sampling=None, sampling=None,
experimental=None, experimental=None,
roots={ roots=RootsCapability(
# TODO: Should this be based on whether we # TODO: Should this be based on whether we
# _will_ send notifications, or only whether # _will_ send notifications, or only whether
# they're supported? # they're supported?
"listChanged": True listChanged=True
}, ),
), ),
clientInfo=Implementation(name="mcp_python", version="0.1.0"), clientInfo=Implementation(name="mcp_python", version="0.1.0"),
), ),

View File

@@ -12,9 +12,19 @@ from mcp_python.types import JSONRPCMessage
# Environment variables to inherit by default # Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = ( DEFAULT_INHERITED_ENV_VARS = (
["APPDATA", "HOMEDRIVE", "HOMEPATH", "LOCALAPPDATA", "PATH", [
"PROCESSOR_ARCHITECTURE", "SYSTEMDRIVE", "SYSTEMROOT", "TEMP", "APPDATA",
"USERNAME", "USERPROFILE"] "HOMEDRIVE",
"HOMEPATH",
"LOCALAPPDATA",
"PATH",
"PROCESSOR_ARCHITECTURE",
"SYSTEMDRIVE",
"SYSTEMROOT",
"TEMP",
"USERNAME",
"USERPROFILE",
]
if sys.platform == "win32" if sys.platform == "win32"
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"] else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
) )
@@ -74,7 +84,7 @@ async def stdio_client(server: StdioServerParameters):
process = await anyio.open_process( process = await anyio.open_process(
[server.command, *server.args], [server.command, *server.args],
env=server.env if server.env is not None else get_default_environment(), env=server.env if server.env is not None else get_default_environment(),
stderr=sys.stderr stderr=sys.stderr,
) )
async def stdout_reader(): async def stdout_reader():

View File

@@ -28,22 +28,26 @@ from mcp_python.types import (
ListResourcesResult, ListResourcesResult,
ListToolsRequest, ListToolsRequest,
ListToolsResult, ListToolsResult,
LoggingCapability,
LoggingLevel, LoggingLevel,
PingRequest, PingRequest,
ProgressNotification, ProgressNotification,
Prompt, Prompt,
PromptMessage, PromptMessage,
PromptReference, PromptReference,
PromptsCapability,
ReadResourceRequest, ReadResourceRequest,
ReadResourceResult, ReadResourceResult,
Resource, Resource,
ResourceReference, ResourceReference,
ResourcesCapability,
ServerCapabilities, ServerCapabilities,
ServerResult, ServerResult,
SetLevelRequest, SetLevelRequest,
SubscribeRequest, SubscribeRequest,
TextContent, TextContent,
Tool, Tool,
ToolsCapability,
UnsubscribeRequest, UnsubscribeRequest,
) )
@@ -54,6 +58,18 @@ request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
) )
class NotificationOptions:
def __init__(
self,
prompts_changed: bool = False,
resources_changed: bool = False,
tools_changed: bool = False,
):
self.prompts_changed = prompts_changed
self.resources_changed = resources_changed
self.tools_changed = tools_changed
class Server: class Server:
def __init__(self, name: str): def __init__(self, name: str):
self.name = name self.name = name
@@ -61,9 +77,14 @@ class Server:
PingRequest: _ping_handler, PingRequest: _ping_handler,
} }
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
self.notification_options = NotificationOptions()
logger.debug(f"Initializing server '{name}'") logger.debug(f"Initializing server '{name}'")
def create_initialization_options(self) -> types.InitializationOptions: def create_initialization_options(
self,
notification_options: NotificationOptions | None = None,
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
) -> types.InitializationOptions:
"""Create initialization options from this server instance.""" """Create initialization options from this server instance."""
def pkg_version(package: str) -> str: def pkg_version(package: str) -> str:
@@ -81,20 +102,51 @@ class Server:
return types.InitializationOptions( return types.InitializationOptions(
server_name=self.name, server_name=self.name,
server_version=pkg_version("mcp_python"), server_version=pkg_version("mcp_python"),
capabilities=self.get_capabilities(), capabilities=self.get_capabilities(
notification_options or NotificationOptions(),
experimental_capabilities or {},
),
) )
def get_capabilities(self) -> ServerCapabilities: def get_capabilities(
self,
notification_options: NotificationOptions,
experimental_capabilities: dict[str, dict[str, Any]],
) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object.""" """Convert existing handlers to a ServerCapabilities object."""
prompts_capability = None
resources_capability = None
tools_capability = None
logging_capability = None
def get_capability(req_type: type) -> dict[str, Any] | None: # Set prompt capabilities if handler exists
return {} if req_type in self.request_handlers else None if ListPromptsRequest in self.request_handlers:
prompts_capability = PromptsCapability(
listChanged=notification_options.prompts_changed
)
# Set resource capabilities if handler exists
if ListResourcesRequest in self.request_handlers:
resources_capability = ResourcesCapability(
subscribe=False, listChanged=notification_options.resources_changed
)
# Set tool capabilities if handler exists
if ListToolsRequest in self.request_handlers:
tools_capability = ToolsCapability(
listChanged=notification_options.tools_changed
)
# Set logging capabilities if handler exists
if SetLevelRequest in self.request_handlers:
logging_capability = LoggingCapability()
return ServerCapabilities( return ServerCapabilities(
prompts=get_capability(ListPromptsRequest), prompts=prompts_capability,
resources=get_capability(ListResourcesRequest), resources=resources_capability,
tools=get_capability(ListToolsRequest), tools=tools_capability,
logging=get_capability(SetLevelRequest), logging=logging_capability,
experimental=experimental_capabilities,
) )
@property @property

View File

@@ -184,30 +184,76 @@ class Implementation(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class RootsCapability(BaseModel):
"""Capability for root operations."""
listChanged: bool | None = None
"""Whether the client supports notifications for changes to the roots list."""
model_config = ConfigDict(extra="allow")
class SamplingCapability(BaseModel):
"""Capability for logging operations."""
model_config = ConfigDict(extra="allow")
class ClientCapabilities(BaseModel): class ClientCapabilities(BaseModel):
"""Capabilities a client may support.""" """Capabilities a client may support."""
experimental: dict[str, dict[str, Any]] | None = None experimental: dict[str, dict[str, Any]] | None = None
"""Experimental, non-standard capabilities that the client supports.""" """Experimental, non-standard capabilities that the client supports."""
sampling: dict[str, Any] | None = None sampling: SamplingCapability | None = None
"""Present if the client supports sampling from an LLM.""" """Present if the client supports sampling from an LLM."""
roots: dict[str, Any] | None = None roots: RootsCapability | None = None
"""Present if the client supports listing roots.""" """Present if the client supports listing roots."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class PromptsCapability(BaseModel):
"""Capability for prompts operations."""
listChanged: bool | None = None
"""Whether this server supports notifications for changes to the prompt list."""
model_config = ConfigDict(extra="allow")
class ResourcesCapability(BaseModel):
"""Capability for resources operations."""
subscribe: bool | None = None
"""Whether this server supports subscribing to resource updates."""
listChanged: bool | None = None
"""Whether this server supports notifications for changes to the resource list."""
model_config = ConfigDict(extra="allow")
class ToolsCapability(BaseModel):
"""Capability for tools operations."""
listChanged: bool | None = None
"""Whether this server supports notifications for changes to the tool list."""
model_config = ConfigDict(extra="allow")
class LoggingCapability(BaseModel):
"""Capability for logging operations."""
model_config = ConfigDict(extra="allow")
class ServerCapabilities(BaseModel): class ServerCapabilities(BaseModel):
"""Capabilities that a server may support.""" """Capabilities that a server may support."""
experimental: dict[str, dict[str, Any]] | None = None experimental: dict[str, dict[str, Any]] | None = None
"""Experimental, non-standard capabilities that the server supports.""" """Experimental, non-standard capabilities that the server supports."""
logging: dict[str, Any] | None = None logging: LoggingCapability | None = None
"""Present if the server supports sending log messages to the client.""" """Present if the server supports sending log messages to the client."""
prompts: dict[str, Any] | None = None prompts: PromptsCapability | None = None
"""Present if the server offers any prompt templates.""" """Present if the server offers any prompt templates."""
resources: dict[str, Any] | None = None resources: ResourcesCapability | None = None
"""Present if the server offers any resources to read.""" """Present if the server offers any resources to read."""
tools: dict[str, Any] | None = None tools: ToolsCapability | None = None
"""Present if the server offers any tools to call.""" """Present if the server offers any tools to call."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")

View File

@@ -2,13 +2,15 @@ import anyio
import pytest import pytest
from mcp_python.client.session import ClientSession from mcp_python.client.session import ClientSession
from mcp_python.server import Server from mcp_python.server import NotificationOptions, Server
from mcp_python.server.session import ServerSession from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions from mcp_python.server.types import InitializationOptions
from mcp_python.types import ( from mcp_python.types import (
ClientNotification, ClientNotification,
InitializedNotification, InitializedNotification,
JSONRPCMessage, JSONRPCMessage,
PromptsCapability,
ResourcesCapability,
ServerCapabilities, ServerCapabilities,
) )
@@ -71,9 +73,11 @@ async def test_server_session_initialize():
@pytest.mark.anyio @pytest.mark.anyio
async def test_server_capabilities(): async def test_server_capabilities():
server = Server("test") server = Server("test")
notification_options = NotificationOptions()
experimental_capabilities = {}
# Initially no capabilities # Initially no capabilities
caps = server.get_capabilities() caps = server.get_capabilities(notification_options, experimental_capabilities)
assert caps.prompts is None assert caps.prompts is None
assert caps.resources is None assert caps.resources is None
@@ -82,8 +86,8 @@ async def test_server_capabilities():
async def list_prompts(): async def list_prompts():
return [] return []
caps = server.get_capabilities() caps = server.get_capabilities(notification_options, experimental_capabilities)
assert caps.prompts == {} assert caps.prompts == PromptsCapability(listChanged=False)
assert caps.resources is None assert caps.resources is None
# Add a resources handler # Add a resources handler
@@ -91,6 +95,6 @@ async def test_server_capabilities():
async def list_resources(): async def list_resources():
return [] return []
caps = server.get_capabilities() caps = server.get_capabilities(notification_options, experimental_capabilities)
assert caps.prompts == {} assert caps.prompts == PromptsCapability(listChanged=False)
assert caps.resources == {} assert caps.resources == ResourcesCapability(subscribe=False, listChanged=False)