Merge pull request #28 from modelcontextprotocol/justin/spec-updates

Update to spec version 2024-11-05
Justin Spahr-Summers authored on 2024-11-07 11:45:36 +00:00 (committed by GitHub)
8 changed files with 381 additions and 40 deletions

View File

@@ -12,14 +12,20 @@ from mcp_python.types import (
ClientNotification,
ClientRequest,
ClientResult,
CompleteResult,
EmptyResult,
GetPromptResult,
Implementation,
InitializedNotification,
InitializeResult,
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
PromptReference,
ReadResourceResult,
ResourceReference,
ServerNotification,
ServerRequest,
)
@@ -61,7 +67,14 @@ class ClientSession(
params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(
-sampling=None, experimental=None
+sampling=None,
+experimental=None,
+roots={
+    # TODO: Should this be based on whether we
+    # _will_ send notifications, or only whether
+    # they're supported?
+    "listChanged": True
+},
),
clientInfo=Implementation(name="mcp_python", version="0.1.0"),
),
@@ -220,3 +233,80 @@ class ClientSession(
),
CallToolResult,
)
async def list_prompts(self) -> ListPromptsResult:
"""Send a prompts/list request."""
from mcp_python.types import ListPromptsRequest
return await self.send_request(
ClientRequest(
ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
)
async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
"""Send a prompts/get request."""
from mcp_python.types import GetPromptRequest, GetPromptRequestParams
return await self.send_request(
ClientRequest(
GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
)
async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
"""Send a completion/complete request."""
from mcp_python.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)
return await self.send_request(
ClientRequest(
CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
),
)
),
CompleteResult,
)
async def list_tools(self) -> ListToolsResult:
"""Send a tools/list request."""
from mcp_python.types import ListToolsRequest
return await self.send_request(
ClientRequest(
ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
)
async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp_python.types import RootsListChangedNotification
await self.send_notification(
ClientNotification(
RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)
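
Taken together, these additions round out the client side of the 2024-11-05 spec. As a minimal usage sketch (the session is assumed to be already initialized; the import path and the prompt/tool names are assumptions):

from mcp_python.client.session import ClientSession  # import path assumed

async def exercise_new_requests(session: ClientSession) -> None:
    # Enumerate prompts and tools advertised by the server.
    prompts = await session.list_prompts()
    tools = await session.list_tools()
    print([p.name for p in prompts.prompts], [t.name for t in tools.tools])
    # Fetch a (hypothetical) prompt with arguments.
    result = await session.get_prompt("summarize", arguments={"style": "brief"})
    print(result.messages)
    # Fire-and-forget notification; the server may respond with a roots/list request.
    await session.send_roots_list_changed()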

View File

@@ -18,6 +18,7 @@ from mcp_python.types import (
ClientNotification,
ClientRequest,
CompleteRequest,
EmbeddedResource,
EmptyResult,
ErrorData,
JSONRPCMessage,
@@ -31,6 +32,7 @@ from mcp_python.types import (
PingRequest,
ProgressNotification,
Prompt,
PromptMessage,
PromptReference,
ReadResourceRequest,
ReadResourceResult,
@@ -40,6 +42,7 @@ from mcp_python.types import (
ServerResult,
SetLevelRequest,
SubscribeRequest,
TextContent,
Tool,
UnsubscribeRequest,
)
@@ -117,8 +120,6 @@ class Server:
GetPromptRequest,
GetPromptResult,
ImageContent,
-SamplingMessage,
-TextContent,
)
from mcp_python.types import (
Role as Role,
@@ -133,7 +134,7 @@ class Server:
async def handler(req: GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments)
-messages: list[SamplingMessage] = []
+messages: list[PromptMessage] = []
for message in prompt_get.messages:
match message.content:
case str() as text_content:
@@ -144,15 +145,17 @@ class Server:
data=img_content.data,
mimeType=img_content.mime_type,
)
case types.EmbeddedResource() as resource:
content = EmbeddedResource(
type="resource", resource=resource.resource
)
case _:
raise ValueError(
f"Unexpected content type: {type(message.content)}"
)
-sampling_message = SamplingMessage(
-    role=message.role, content=content
-)
-messages.append(sampling_message)
+prompt_message = PromptMessage(role=message.role, content=content)
+messages.append(prompt_message)
return ServerResult(
GetPromptResult(description=prompt_get.desc, messages=messages)
@@ -169,9 +172,7 @@ class Server:
async def handler(_: Any):
resources = await func()
-return ServerResult(
-    ListResourcesResult(resources=resources)
-)
+return ServerResult(ListResourcesResult(resources=resources))
self.request_handlers[ListResourcesRequest] = handler
return func
@@ -216,7 +217,6 @@ class Server:
return decorator
def set_logging_level(self):
from mcp_python.types import EmptyResult
@@ -276,14 +276,51 @@ class Server:
return decorator
def call_tool(self):
-from mcp_python.types import CallToolResult
+from mcp_python.types import (
+    CallToolResult,
+    EmbeddedResource,
+    ImageContent,
+    TextContent,
+)
-def decorator(func: Callable[..., Awaitable[Any]]):
+def decorator(
+    func: Callable[
+        ..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]
+    ],
+):
logger.debug("Registering handler for CallToolRequest")
async def handler(req: CallToolRequest):
-result = await func(req.params.name, (req.params.arguments or {}))
-return ServerResult(CallToolResult(toolResult=result))
+try:
+    results = await func(req.params.name, (req.params.arguments or {}))
+    content = []
+    for result in results:
+        match result:
+            case str() as text:
+                content.append(TextContent(type="text", text=text))
+            case types.ImageContent() as img:
+                content.append(
+                    ImageContent(
+                        type="image",
+                        data=img.data,
+                        mimeType=img.mime_type,
+                    )
+                )
+            case types.EmbeddedResource() as resource:
+                content.append(
+                    EmbeddedResource(
+                        type="resource", resource=resource.resource
+                    )
+                )
+    return ServerResult(CallToolResult(content=content, isError=False))
+except Exception as e:
+    return ServerResult(
+        CallToolResult(
+            content=[TextContent(type="text", text=str(e))],
+            isError=True,
+        )
+    )
self.request_handlers[CallToolRequest] = handler
return func
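
Under the new contract, a tool handler returns a list of str | types.ImageContent | types.EmbeddedResource and the framework builds the CallToolResult itself, mapping exceptions to isError=True. A sketch of registering such a handler (the server name, tool name, and constructor signature are assumptions):

from mcp_python.server import Server  # import path assumed

server = Server("example-server")  # constructor signature assumed

@server.call_tool()
async def handle_tool(name: str, arguments: dict) -> list[str]:
    # Returned strings become TextContent; raising becomes isError=True.
    if name == "add":
        return [str(arguments["a"] + arguments["b"])]
    raise ValueError(f"Unknown tool: {name}")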

View File

@@ -23,11 +23,16 @@ from mcp_python.types import (
InitializeRequest,
InitializeResult,
JSONRPCMessage,
ListRootsResult,
LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage,
ServerNotification,
ServerRequest,
ServerResult,
ToolListChangedNotification,
)
@@ -132,7 +137,7 @@ class ServerSession(
)
)
-async def request_create_message(
+async def create_message(
self,
messages: list[SamplingMessage],
*,
@@ -142,6 +147,7 @@ class ServerSession(
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
) -> CreateMessageResult:
"""Send a sampling/create_message request."""
from mcp_python.types import (
@@ -161,12 +167,26 @@ class ServerSession(
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
),
)
),
CreateMessageResult,
)
async def list_roots(self) -> ListRootsResult:
"""Send a roots/list request."""
from mcp_python.types import ListRootsRequest
return await self.send_request(
ServerRequest(
ListRootsRequest(
method="roots/list",
)
),
ListRootsResult,
)
async def send_ping(self) -> EmptyResult:
"""Send a ping request."""
from mcp_python.types import PingRequest
@@ -198,3 +218,33 @@ class ServerSession(
)
)
)
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
ServerNotification(
ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
)
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
ServerNotification(
ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
)
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
ServerNotification(
PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)
)
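
On the server side, the renamed create_message now accepts model_preferences, and list_roots lets servers discover what the client exposes. A sketch assuming an active ServerSession (the import path and model hint are assumptions):

from mcp_python.server.session import ServerSession  # import path assumed
from mcp_python.types import ModelHint, ModelPreferences, SamplingMessage, TextContent

async def sample_about_roots(session: ServerSession) -> None:
    roots = await session.list_roots()  # round-trips roots/list to the client
    result = await session.create_message(
        messages=[
            SamplingMessage(
                role="user",
                content=TextContent(
                    type="text",
                    text=f"The client exposes {len(roots.roots)} roots.",
                ),
            )
        ],
        max_tokens=100,
        model_preferences=ModelPreferences(
            hints=[ModelHint(name="claude-3-5-sonnet")],  # hypothetical hint
            speedPriority=0.8,
        ),
    )
    print(result.model, result.stopReason)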

View File

@@ -1,5 +1,6 @@
"""
-This module provides simpler types to use with the server for managing prompts.
+This module provides simpler types to use with the server for managing prompts
+and tools.
"""
from dataclasses import dataclass
@@ -7,7 +8,12 @@ from typing import Literal
from pydantic import BaseModel
-from mcp_python.types import Role, ServerCapabilities
+from mcp_python.types import (
+    BlobResourceContents,
+    Role,
+    ServerCapabilities,
+    TextResourceContents,
+)
@dataclass
@@ -17,10 +23,15 @@ class ImageContent:
mime_type: str
@dataclass
class EmbeddedResource:
resource: TextResourceContents | BlobResourceContents
@dataclass
class Message:
role: Role
-content: str | ImageContent
+content: str | ImageContent | EmbeddedResource
@dataclass
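
With EmbeddedResource in the simplified types, a prompt handler can now return resource content directly. A sketch (the module path and file contents are assumptions; pydantic coerces the uri string):

from mcp_python.server.types import EmbeddedResource, Message  # import path assumed
from mcp_python.types import TextResourceContents

message = Message(
    role="user",
    content=EmbeddedResource(
        resource=TextResourceContents(
            uri="file:///project/README.md",  # coerced to a URL by pydantic
            mimeType="text/markdown",
            text="# Example project",
        )
    ),
)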

View File

@@ -15,14 +15,14 @@ from mcp_python.types import JSONRPCMessage
MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
-MemoryObjectSendStream[JSONRPCMessage]
+MemoryObjectSendStream[JSONRPCMessage],
]
@asynccontextmanager
-async def create_client_server_memory_streams() -> AsyncGenerator[
-    tuple[MessageStream, MessageStream],
-    None
-]:
+async def create_client_server_memory_streams() -> (
+    AsyncGenerator[tuple[MessageStream, MessageStream], None]
+):
"""
Creates a pair of bidirectional memory streams for client-server communication.
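
For reference, each MessageStream is a (receive, send) pair and the context manager yields one pair per side, crosswired so that whatever one side sends, the other receives. A usage sketch (import path assumed):

from mcp_python.shared.memory import create_client_server_memory_streams  # path assumed

async def run_in_memory() -> None:
    async with create_client_server_memory_streams() as (client_streams, server_streams):
        client_read, client_write = client_streams
        server_read, server_write = server_streams
        ...  # hand each pair to a ClientSession / Server run loop in tests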

View File

@@ -154,7 +154,8 @@ class BaseSession(
try:
with anyio.fail_after(
-None if self._read_timeout_seconds is None
+None
+if self._read_timeout_seconds is None
else self._read_timeout_seconds.total_seconds()
):
response_or_error = await response_stream_reader.receive()
@@ -168,7 +169,6 @@ class BaseSession(
f"{self._read_timeout_seconds} seconds."
),
)
-)
if isinstance(response_or_error, JSONRPCError):
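
The reflowed conditional is formatting-only; the underlying idiom is that anyio.fail_after(None) disables the deadline entirely. A standalone sketch of the same pattern:

import anyio
from datetime import timedelta

async def receive_with_optional_timeout(receive, read_timeout: timedelta | None):
    # fail_after(None) imposes no deadline; otherwise convert the timedelta to seconds.
    with anyio.fail_after(
        None if read_timeout is None else read_timeout.total_seconds()
    ):
        return await receive()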

View File

@@ -1,12 +1,12 @@
from typing import Any, Generic, Literal, TypeVar
-from pydantic import BaseModel, ConfigDict, RootModel
+from pydantic import BaseModel, ConfigDict, FileUrl, RootModel
from pydantic.networks import AnyUrl
"""
Model Context Protocol bindings for Python
-These bindings were generated from https://github.com/anthropic-experimental/mcp-spec,
+These bindings were generated from https://github.com/modelcontextprotocol/specification,
using Claude, with a prompt something like the following:
Generate idiomatic Python bindings for this schema for MCP, or the "Model Context
@@ -21,7 +21,7 @@ for reference.
not separate types in the schema.
"""
LATEST_PROTOCOL_VERSION = "2024-10-07"
LATEST_PROTOCOL_VERSION = "2024-11-05"
ProgressToken = str | int
Cursor = str
@@ -191,6 +191,8 @@ class ClientCapabilities(BaseModel):
"""Experimental, non-standard capabilities that the client supports."""
sampling: dict[str, Any] | None = None
"""Present if the client supports sampling from an LLM."""
roots: dict[str, Any] | None = None
"""Present if the client supports listing roots."""
model_config = ConfigDict(extra="allow")
@@ -556,12 +558,33 @@ class SamplingMessage(BaseModel):
model_config = ConfigDict(extra="allow")
class EmbeddedResource(BaseModel):
"""
The contents of a resource, embedded into a prompt or tool call result.
It is up to the client how best to render embedded resources for the benefit
of the LLM and/or the user.
"""
type: Literal["resource"]
resource: TextResourceContents | BlobResourceContents
model_config = ConfigDict(extra="allow")
class PromptMessage(BaseModel):
"""Describes a message returned as part of a prompt."""
role: Role
content: TextContent | ImageContent | EmbeddedResource
model_config = ConfigDict(extra="allow")
class GetPromptResult(Result):
"""The server's response to a prompts/get request from the client."""
description: str | None = None
"""An optional description for the prompt."""
-messages: list[SamplingMessage]
+messages: list[PromptMessage]
class PromptListChangedNotification(Notification):
@@ -617,7 +640,8 @@ class CallToolRequest(Request):
class CallToolResult(Result):
"""The server's response to a tool call."""
-toolResult: Any
+content: list[TextContent | ImageContent | EmbeddedResource]
+isError: bool
class ToolListChangedNotification(Notification):
@@ -630,7 +654,9 @@ class ToolListChangedNotification(Notification):
params: NotificationParams | None = None
LoggingLevel = Literal["debug", "info", "warning", "error"]
LoggingLevel = Literal[
"debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"
]
class SetLevelRequestParams(RequestParams):
@@ -673,10 +699,75 @@ class LoggingMessageNotification(Notification):
IncludeContext = Literal["none", "thisServer", "allServers"]
class ModelHint(BaseModel):
"""Hints to use for model selection."""
name: str | None = None
"""A hint for a model name."""
model_config = ConfigDict(extra="allow")
class ModelPreferences(BaseModel):
"""
The server's preferences for model selection, requested of the client during
sampling.
Because LLMs can vary along multiple dimensions, choosing the "best" model is
rarely straightforward. Different models excel in different areas—some are
faster but less capable, others are more capable but more expensive, and so
on. This interface allows servers to express their priorities across multiple
dimensions to help clients make an appropriate selection for their use case.
These preferences are always advisory. The client MAY ignore them. It is also
up to the client to decide how to interpret these preferences and how to
balance them against other considerations.
"""
hints: list[ModelHint] | None = None
"""
Optional hints to use for model selection.
If multiple hints are specified, the client MUST evaluate them in order
(such that the first match is taken).
The client SHOULD prioritize these hints over the numeric priorities, but
MAY still use the priorities to select from ambiguous matches.
"""
costPriority: float | None = None
"""
How much to prioritize cost when selecting a model. A value of 0 means cost
is not important, while a value of 1 means cost is the most important
factor.
"""
speedPriority: float | None = None
"""
How much to prioritize sampling speed (latency) when selecting a model. A
value of 0 means speed is not important, while a value of 1 means speed is
the most important factor.
"""
intelligencePriority: float | None = None
"""
How much to prioritize intelligence and capabilities when selecting a
model. A value of 0 means intelligence is not important, while a value of 1
means intelligence is the most important factor.
"""
model_config = ConfigDict(extra="allow")
class CreateMessageRequestParams(RequestParams):
"""Parameters for creating a message."""
messages: list[SamplingMessage]
modelPreferences: ModelPreferences | None = None
"""
The server's preferences for which model to select. The client MAY ignore
these preferences.
"""
systemPrompt: str | None = None
"""An optional system prompt the server wants to use for sampling."""
includeContext: IncludeContext | None = None
@@ -700,7 +791,7 @@ class CreateMessageRequest(Request):
params: CreateMessageRequestParams
StopReason = Literal["endTurn", "stopSequence", "maxTokens"]
StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str
class CreateMessageResult(Result):
@@ -710,8 +801,8 @@ class CreateMessageResult(Result):
content: TextContent | ImageContent
model: str
"""The name of the model that generated the message."""
-stopReason: StopReason
-"""The reason why sampling stopped."""
+stopReason: StopReason | None = None
+"""The reason why sampling stopped, if known."""
class ResourceReference(BaseModel):
@@ -781,6 +872,63 @@ class CompleteResult(Result):
completion: Completion
class ListRootsRequest(Request):
"""
Sent from the server to request a list of root URIs from the client. Roots allow
servers to ask for specific directories or files to operate on. A common example
for roots is providing a set of repositories or directories a server should operate
on.
This request is typically used when the server needs to understand the file system
structure or access specific locations that the client has permission to read from.
"""
method: Literal["roots/list"]
params: RequestParams | None = None
class Root(BaseModel):
"""Represents a root directory or file that the server can operate on."""
uri: FileUrl
"""
The URI identifying the root. This *must* start with file:// for now.
This restriction may be relaxed in future versions of the protocol to allow
other URI schemes.
"""
name: str | None = None
"""
An optional name for the root. This can be used to provide a human-readable
identifier for the root, which may be useful for display purposes or for
referencing the root in other parts of the application.
"""
model_config = ConfigDict(extra="allow")
class ListRootsResult(Result):
"""
The client's response to a roots/list request from the server.
This result contains an array of Root objects, each representing a root directory
or file that the server can operate on.
"""
roots: list[Root]
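# Editor's illustration: a client exposing a single repository as a root
# (pydantic coerces the uri string to a FileUrl).
#
#   ListRootsResult(
#       roots=[Root(uri="file:///home/user/project", name="project")]
#   )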
class RootsListChangedNotification(Notification):
"""
A notification from the client to the server, informing it that the list of
roots has changed.
This notification should be sent whenever the client adds, removes, or
modifies any root. The server should then request an updated list of roots
using the ListRootsRequest.
"""
method: Literal["notifications/roots/list_changed"]
params: NotificationParams | None = None
class ClientRequest(
RootModel[
PingRequest
@@ -801,15 +949,19 @@ class ClientRequest(
pass
-class ClientNotification(RootModel[ProgressNotification | InitializedNotification]):
+class ClientNotification(
+    RootModel[
+        ProgressNotification | InitializedNotification | RootsListChangedNotification
+    ]
+):
pass
-class ClientResult(RootModel[EmptyResult | CreateMessageResult]):
+class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]):
pass
-class ServerRequest(RootModel[PingRequest | CreateMessageRequest]):
+class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]):
pass