Merge pull request #28 from modelcontextprotocol/justin/spec-updates

Update to spec version 2024-11-05
This commit is contained in:
Justin Spahr-Summers
2024-11-07 11:45:36 +00:00
committed by GitHub
8 changed files with 381 additions and 40 deletions

View File

@@ -12,14 +12,20 @@ from mcp_python.types import (
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
ClientResult, ClientResult,
CompleteResult,
EmptyResult, EmptyResult,
GetPromptResult,
Implementation, Implementation,
InitializedNotification, InitializedNotification,
InitializeResult, InitializeResult,
JSONRPCMessage, JSONRPCMessage,
ListPromptsResult,
ListResourcesResult, ListResourcesResult,
ListToolsResult,
LoggingLevel, LoggingLevel,
PromptReference,
ReadResourceResult, ReadResourceResult,
ResourceReference,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
) )
@@ -61,7 +67,14 @@ class ClientSession(
params=InitializeRequestParams( params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION, protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities( capabilities=ClientCapabilities(
sampling=None, experimental=None sampling=None,
experimental=None,
roots={
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
"listChanged": True
},
), ),
clientInfo=Implementation(name="mcp_python", version="0.1.0"), clientInfo=Implementation(name="mcp_python", version="0.1.0"),
), ),
@@ -220,3 +233,80 @@ class ClientSession(
), ),
CallToolResult, CallToolResult,
) )
async def list_prompts(self) -> ListPromptsResult:
"""Send a prompts/list request."""
from mcp_python.types import ListPromptsRequest
return await self.send_request(
ClientRequest(
ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
)
async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
"""Send a prompts/get request."""
from mcp_python.types import GetPromptRequest, GetPromptRequestParams
return await self.send_request(
ClientRequest(
GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
)
async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
"""Send a completion/complete request."""
from mcp_python.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)
return await self.send_request(
ClientRequest(
CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
),
)
),
CompleteResult,
)
async def list_tools(self) -> ListToolsResult:
"""Send a tools/list request."""
from mcp_python.types import ListToolsRequest
return await self.send_request(
ClientRequest(
ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
)
async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp_python.types import RootsListChangedNotification
await self.send_notification(
ClientNotification(
RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)

View File

@@ -18,6 +18,7 @@ from mcp_python.types import (
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
CompleteRequest, CompleteRequest,
EmbeddedResource,
EmptyResult, EmptyResult,
ErrorData, ErrorData,
JSONRPCMessage, JSONRPCMessage,
@@ -31,6 +32,7 @@ from mcp_python.types import (
PingRequest, PingRequest,
ProgressNotification, ProgressNotification,
Prompt, Prompt,
PromptMessage,
PromptReference, PromptReference,
ReadResourceRequest, ReadResourceRequest,
ReadResourceResult, ReadResourceResult,
@@ -40,6 +42,7 @@ from mcp_python.types import (
ServerResult, ServerResult,
SetLevelRequest, SetLevelRequest,
SubscribeRequest, SubscribeRequest,
TextContent,
Tool, Tool,
UnsubscribeRequest, UnsubscribeRequest,
) )
@@ -117,8 +120,6 @@ class Server:
GetPromptRequest, GetPromptRequest,
GetPromptResult, GetPromptResult,
ImageContent, ImageContent,
SamplingMessage,
TextContent,
) )
from mcp_python.types import ( from mcp_python.types import (
Role as Role, Role as Role,
@@ -133,7 +134,7 @@ class Server:
async def handler(req: GetPromptRequest): async def handler(req: GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments) prompt_get = await func(req.params.name, req.params.arguments)
messages: list[SamplingMessage] = [] messages: list[PromptMessage] = []
for message in prompt_get.messages: for message in prompt_get.messages:
match message.content: match message.content:
case str() as text_content: case str() as text_content:
@@ -144,15 +145,17 @@ class Server:
data=img_content.data, data=img_content.data,
mimeType=img_content.mime_type, mimeType=img_content.mime_type,
) )
case types.EmbeddedResource() as resource:
content = EmbeddedResource(
type="resource", resource=resource.resource
)
case _: case _:
raise ValueError( raise ValueError(
f"Unexpected content type: {type(message.content)}" f"Unexpected content type: {type(message.content)}"
) )
sampling_message = SamplingMessage( prompt_message = PromptMessage(role=message.role, content=content)
role=message.role, content=content messages.append(prompt_message)
)
messages.append(sampling_message)
return ServerResult( return ServerResult(
GetPromptResult(description=prompt_get.desc, messages=messages) GetPromptResult(description=prompt_get.desc, messages=messages)
@@ -169,9 +172,7 @@ class Server:
async def handler(_: Any): async def handler(_: Any):
resources = await func() resources = await func()
return ServerResult( return ServerResult(ListResourcesResult(resources=resources))
ListResourcesResult(resources=resources)
)
self.request_handlers[ListResourcesRequest] = handler self.request_handlers[ListResourcesRequest] = handler
return func return func
@@ -216,7 +217,6 @@ class Server:
return decorator return decorator
def set_logging_level(self): def set_logging_level(self):
from mcp_python.types import EmptyResult from mcp_python.types import EmptyResult
@@ -276,14 +276,51 @@ class Server:
return decorator return decorator
def call_tool(self): def call_tool(self):
from mcp_python.types import CallToolResult from mcp_python.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
)
def decorator(func: Callable[..., Awaitable[Any]]): def decorator(
func: Callable[
..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]
],
):
logger.debug("Registering handler for CallToolRequest") logger.debug("Registering handler for CallToolRequest")
async def handler(req: CallToolRequest): async def handler(req: CallToolRequest):
result = await func(req.params.name, (req.params.arguments or {})) try:
return ServerResult(CallToolResult(toolResult=result)) results = await func(req.params.name, (req.params.arguments or {}))
content = []
for result in results:
match result:
case str() as text:
content.append(TextContent(type="text", text=text))
case types.ImageContent() as img:
content.append(
ImageContent(
type="image",
data=img.data,
mimeType=img.mime_type,
)
)
case types.EmbeddedResource() as resource:
content.append(
EmbeddedResource(
type="resource", resource=resource.resource
)
)
return ServerResult(CallToolResult(content=content, isError=False))
except Exception as e:
return ServerResult(
CallToolResult(
content=[TextContent(type="text", text=str(e))],
isError=True,
)
)
self.request_handlers[CallToolRequest] = handler self.request_handlers[CallToolRequest] = handler
return func return func

View File

@@ -23,11 +23,16 @@ from mcp_python.types import (
InitializeRequest, InitializeRequest,
InitializeResult, InitializeResult,
JSONRPCMessage, JSONRPCMessage,
ListRootsResult,
LoggingLevel, LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage, SamplingMessage,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
ServerResult, ServerResult,
ToolListChangedNotification,
) )
@@ -132,7 +137,7 @@ class ServerSession(
) )
) )
async def request_create_message( async def create_message(
self, self,
messages: list[SamplingMessage], messages: list[SamplingMessage],
*, *,
@@ -142,6 +147,7 @@ class ServerSession(
temperature: float | None = None, temperature: float | None = None,
stop_sequences: list[str] | None = None, stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
) -> CreateMessageResult: ) -> CreateMessageResult:
"""Send a sampling/create_message request.""" """Send a sampling/create_message request."""
from mcp_python.types import ( from mcp_python.types import (
@@ -161,12 +167,26 @@ class ServerSession(
maxTokens=max_tokens, maxTokens=max_tokens,
stopSequences=stop_sequences, stopSequences=stop_sequences,
metadata=metadata, metadata=metadata,
modelPreferences=model_preferences,
), ),
) )
), ),
CreateMessageResult, CreateMessageResult,
) )
async def list_roots(self) -> ListRootsResult:
"""Send a roots/list request."""
from mcp_python.types import ListRootsRequest
return await self.send_request(
ServerRequest(
ListRootsRequest(
method="roots/list",
)
),
ListRootsResult,
)
async def send_ping(self) -> EmptyResult: async def send_ping(self) -> EmptyResult:
"""Send a ping request.""" """Send a ping request."""
from mcp_python.types import PingRequest from mcp_python.types import PingRequest
@@ -198,3 +218,33 @@ class ServerSession(
) )
) )
) )
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
ServerNotification(
ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
)
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
ServerNotification(
ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
)
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
ServerNotification(
PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)
)

View File

@@ -1,5 +1,6 @@
""" """
This module provides simpler types to use with the server for managing prompts. This module provides simpler types to use with the server for managing prompts
and tools.
""" """
from dataclasses import dataclass from dataclasses import dataclass
@@ -7,7 +8,12 @@ from typing import Literal
from pydantic import BaseModel from pydantic import BaseModel
from mcp_python.types import Role, ServerCapabilities from mcp_python.types import (
BlobResourceContents,
Role,
ServerCapabilities,
TextResourceContents,
)
@dataclass @dataclass
@@ -17,10 +23,15 @@ class ImageContent:
mime_type: str mime_type: str
@dataclass
class EmbeddedResource:
resource: TextResourceContents | BlobResourceContents
@dataclass @dataclass
class Message: class Message:
role: Role role: Role
content: str | ImageContent content: str | ImageContent | EmbeddedResource
@dataclass @dataclass

View File

@@ -15,14 +15,14 @@ from mcp_python.types import JSONRPCMessage
MessageStream = tuple[ MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage] MemoryObjectSendStream[JSONRPCMessage],
] ]
@asynccontextmanager @asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[ async def create_client_server_memory_streams() -> (
tuple[MessageStream, MessageStream], AsyncGenerator[tuple[MessageStream, MessageStream], None]
None ):
]:
""" """
Creates a pair of bidirectional memory streams for client-server communication. Creates a pair of bidirectional memory streams for client-server communication.

View File

@@ -154,7 +154,8 @@ class BaseSession(
try: try:
with anyio.fail_after( with anyio.fail_after(
None if self._read_timeout_seconds is None None
if self._read_timeout_seconds is None
else self._read_timeout_seconds.total_seconds() else self._read_timeout_seconds.total_seconds()
): ):
response_or_error = await response_stream_reader.receive() response_or_error = await response_stream_reader.receive()
@@ -168,7 +169,6 @@ class BaseSession(
f"{self._read_timeout_seconds} seconds." f"{self._read_timeout_seconds} seconds."
), ),
) )
) )
if isinstance(response_or_error, JSONRPCError): if isinstance(response_or_error, JSONRPCError):

View File

@@ -1,12 +1,12 @@
from typing import Any, Generic, Literal, TypeVar from typing import Any, Generic, Literal, TypeVar
from pydantic import BaseModel, ConfigDict, RootModel from pydantic import BaseModel, ConfigDict, FileUrl, RootModel
from pydantic.networks import AnyUrl from pydantic.networks import AnyUrl
""" """
Model Context Protocol bindings for Python Model Context Protocol bindings for Python
These bindings were generated from https://github.com/anthropic-experimental/mcp-spec, These bindings were generated from https://github.com/modelcontextprotocol/specification,
using Claude, with a prompt something like the following: using Claude, with a prompt something like the following:
Generate idiomatic Python bindings for this schema for MCP, or the "Model Context Generate idiomatic Python bindings for this schema for MCP, or the "Model Context
@@ -21,7 +21,7 @@ for reference.
not separate types in the schema. not separate types in the schema.
""" """
LATEST_PROTOCOL_VERSION = "2024-10-07" LATEST_PROTOCOL_VERSION = "2024-11-05"
ProgressToken = str | int ProgressToken = str | int
Cursor = str Cursor = str
@@ -191,6 +191,8 @@ class ClientCapabilities(BaseModel):
"""Experimental, non-standard capabilities that the client supports.""" """Experimental, non-standard capabilities that the client supports."""
sampling: dict[str, Any] | None = None sampling: dict[str, Any] | None = None
"""Present if the client supports sampling from an LLM.""" """Present if the client supports sampling from an LLM."""
roots: dict[str, Any] | None = None
"""Present if the client supports listing roots."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@@ -556,12 +558,33 @@ class SamplingMessage(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class EmbeddedResource(BaseModel):
"""
The contents of a resource, embedded into a prompt or tool call result.
It is up to the client how best to render embedded resources for the benefit
of the LLM and/or the user.
"""
type: Literal["resource"]
resource: TextResourceContents | BlobResourceContents
model_config = ConfigDict(extra="allow")
class PromptMessage(BaseModel):
"""Describes a message returned as part of a prompt."""
role: Role
content: TextContent | ImageContent | EmbeddedResource
model_config = ConfigDict(extra="allow")
class GetPromptResult(Result): class GetPromptResult(Result):
"""The server's response to a prompts/get request from the client.""" """The server's response to a prompts/get request from the client."""
description: str | None = None description: str | None = None
"""An optional description for the prompt.""" """An optional description for the prompt."""
messages: list[SamplingMessage] messages: list[PromptMessage]
class PromptListChangedNotification(Notification): class PromptListChangedNotification(Notification):
@@ -617,7 +640,8 @@ class CallToolRequest(Request):
class CallToolResult(Result): class CallToolResult(Result):
"""The server's response to a tool call.""" """The server's response to a tool call."""
toolResult: Any content: list[TextContent | ImageContent | EmbeddedResource]
isError: bool
class ToolListChangedNotification(Notification): class ToolListChangedNotification(Notification):
@@ -630,7 +654,9 @@ class ToolListChangedNotification(Notification):
params: NotificationParams | None = None params: NotificationParams | None = None
LoggingLevel = Literal["debug", "info", "warning", "error"] LoggingLevel = Literal[
"debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"
]
class SetLevelRequestParams(RequestParams): class SetLevelRequestParams(RequestParams):
@@ -673,10 +699,75 @@ class LoggingMessageNotification(Notification):
IncludeContext = Literal["none", "thisServer", "allServers"] IncludeContext = Literal["none", "thisServer", "allServers"]
class ModelHint(BaseModel):
"""Hints to use for model selection."""
name: str | None = None
"""A hint for a model name."""
model_config = ConfigDict(extra="allow")
class ModelPreferences(BaseModel):
"""
The server's preferences for model selection, requested of the client during
sampling.
Because LLMs can vary along multiple dimensions, choosing the "best" model is
rarely straightforward. Different models excel in different areas—some are
faster but less capable, others are more capable but more expensive, and so
on. This interface allows servers to express their priorities across multiple
dimensions to help clients make an appropriate selection for their use case.
These preferences are always advisory. The client MAY ignore them. It is also
up to the client to decide how to interpret these preferences and how to
balance them against other considerations.
"""
hints: list[ModelHint] | None = None
"""
Optional hints to use for model selection.
If multiple hints are specified, the client MUST evaluate them in order
(such that the first match is taken).
The client SHOULD prioritize these hints over the numeric priorities, but
MAY still use the priorities to select from ambiguous matches.
"""
costPriority: float | None = None
"""
How much to prioritize cost when selecting a model. A value of 0 means cost
is not important, while a value of 1 means cost is the most important
factor.
"""
speedPriority: float | None = None
"""
How much to prioritize sampling speed (latency) when selecting a model. A
value of 0 means speed is not important, while a value of 1 means speed is
the most important factor.
"""
intelligencePriority: float | None = None
"""
How much to prioritize intelligence and capabilities when selecting a
model. A value of 0 means intelligence is not important, while a value of 1
means intelligence is the most important factor.
"""
model_config = ConfigDict(extra="allow")
class CreateMessageRequestParams(RequestParams): class CreateMessageRequestParams(RequestParams):
"""Parameters for creating a message.""" """Parameters for creating a message."""
messages: list[SamplingMessage] messages: list[SamplingMessage]
modelPreferences: ModelPreferences | None = None
"""
The server's preferences for which model to select. The client MAY ignore
these preferences.
"""
systemPrompt: str | None = None systemPrompt: str | None = None
"""An optional system prompt the server wants to use for sampling.""" """An optional system prompt the server wants to use for sampling."""
includeContext: IncludeContext | None = None includeContext: IncludeContext | None = None
@@ -700,7 +791,7 @@ class CreateMessageRequest(Request):
params: CreateMessageRequestParams params: CreateMessageRequestParams
StopReason = Literal["endTurn", "stopSequence", "maxTokens"] StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str
class CreateMessageResult(Result): class CreateMessageResult(Result):
@@ -710,8 +801,8 @@ class CreateMessageResult(Result):
content: TextContent | ImageContent content: TextContent | ImageContent
model: str model: str
"""The name of the model that generated the message.""" """The name of the model that generated the message."""
stopReason: StopReason stopReason: StopReason | None = None
"""The reason why sampling stopped.""" """The reason why sampling stopped, if known."""
class ResourceReference(BaseModel): class ResourceReference(BaseModel):
@@ -781,6 +872,63 @@ class CompleteResult(Result):
completion: Completion completion: Completion
class ListRootsRequest(Request):
"""
Sent from the server to request a list of root URIs from the client. Roots allow
servers to ask for specific directories or files to operate on. A common example
for roots is providing a set of repositories or directories a server should operate
on.
This request is typically used when the server needs to understand the file system
structure or access specific locations that the client has permission to read from.
"""
method: Literal["roots/list"]
params: RequestParams | None = None
class Root(BaseModel):
"""Represents a root directory or file that the server can operate on."""
uri: FileUrl
"""
The URI identifying the root. This *must* start with file:// for now.
This restriction may be relaxed in future versions of the protocol to allow
other URI schemes.
"""
name: str | None = None
"""
An optional name for the root. This can be used to provide a human-readable
identifier for the root, which may be useful for display purposes or for
referencing the root in other parts of the application.
"""
model_config = ConfigDict(extra="allow")
class ListRootsResult(Result):
"""
The client's response to a roots/list request from the server.
This result contains an array of Root objects, each representing a root directory
or file that the server can operate on.
"""
roots: list[Root]
class RootsListChangedNotification(Notification):
"""
A notification from the client to the server, informing it that the list of
roots has changed.
This notification should be sent whenever the client adds, removes, or
modifies any root. The server should then request an updated list of roots
using the ListRootsRequest.
"""
method: Literal["notifications/roots/list_changed"]
params: NotificationParams | None = None
class ClientRequest( class ClientRequest(
RootModel[ RootModel[
PingRequest PingRequest
@@ -801,15 +949,19 @@ class ClientRequest(
pass pass
class ClientNotification(RootModel[ProgressNotification | InitializedNotification]): class ClientNotification(
RootModel[
ProgressNotification | InitializedNotification | RootsListChangedNotification
]
):
pass pass
class ClientResult(RootModel[EmptyResult | CreateMessageResult]): class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]):
pass pass
class ServerRequest(RootModel[PingRequest | CreateMessageRequest]): class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]):
pass pass

View File

@@ -11,6 +11,7 @@ TEST_INITIALIZATION_OPTIONS = InitializationOptions(
capabilities=ServerCapabilities(), capabilities=ServerCapabilities(),
) )
@pytest.fixture @pytest.fixture
def mcp_server() -> Server: def mcp_server() -> Server:
server = Server(name="test_server") server = Server(name="test_server")
@@ -21,7 +22,7 @@ def mcp_server() -> Server:
Resource( Resource(
uri=AnyUrl("memory://test"), uri=AnyUrl("memory://test"),
name="Test Resource", name="Test Resource",
description="A test resource" description="A test resource",
) )
] ]