Formatting

This commit is contained in:
Justin Spahr-Summers
2024-11-06 12:35:32 +00:00
parent a891ad4689
commit c7d8f11e0c
8 changed files with 87 additions and 58 deletions

View File

@@ -1,7 +1,7 @@
from datetime import timedelta from datetime import timedelta
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, FileUrl from pydantic import AnyUrl
from mcp_python.shared.session import BaseSession from mcp_python.shared.session import BaseSession
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -21,7 +21,6 @@ from mcp_python.types import (
JSONRPCMessage, JSONRPCMessage,
ListPromptsResult, ListPromptsResult,
ListResourcesResult, ListResourcesResult,
ListRootsResult,
ListToolsResult, ListToolsResult,
LoggingLevel, LoggingLevel,
PromptReference, PromptReference,
@@ -71,9 +70,11 @@ class ClientSession(
sampling=None, sampling=None,
experimental=None, experimental=None,
roots={ roots={
# TODO: Should this be based on whether we _will_ send notifications, or only whether they're supported? # TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
"listChanged": True "listChanged": True
} },
), ),
clientInfo=Implementation(name="mcp_python", version="0.1.0"), clientInfo=Implementation(name="mcp_python", version="0.1.0"),
), ),
@@ -246,7 +247,9 @@ class ClientSession(
ListPromptsResult, ListPromptsResult,
) )
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> GetPromptResult: async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
"""Send a prompts/get request.""" """Send a prompts/get request."""
from mcp_python.types import GetPromptRequest, GetPromptRequestParams from mcp_python.types import GetPromptRequest, GetPromptRequestParams
@@ -260,9 +263,15 @@ class ClientSession(
GetPromptResult, GetPromptResult,
) )
async def complete(self, ref: ResourceReference | PromptReference, argument: dict) -> CompleteResult: async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
"""Send a completion/complete request.""" """Send a completion/complete request."""
from mcp_python.types import CompleteRequest, CompleteRequestParams, CompletionArgument from mcp_python.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)
return await self.send_request( return await self.send_request(
ClientRequest( ClientRequest(

View File

@@ -18,6 +18,7 @@ from mcp_python.types import (
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
CompleteRequest, CompleteRequest,
EmbeddedResource,
EmptyResult, EmptyResult,
ErrorData, ErrorData,
JSONRPCMessage, JSONRPCMessage,
@@ -31,6 +32,7 @@ from mcp_python.types import (
PingRequest, PingRequest,
ProgressNotification, ProgressNotification,
Prompt, Prompt,
PromptMessage,
PromptReference, PromptReference,
ReadResourceRequest, ReadResourceRequest,
ReadResourceResult, ReadResourceResult,
@@ -40,11 +42,9 @@ from mcp_python.types import (
ServerResult, ServerResult,
SetLevelRequest, SetLevelRequest,
SubscribeRequest, SubscribeRequest,
TextContent,
Tool, Tool,
UnsubscribeRequest, UnsubscribeRequest,
TextContent,
EmbeddedResource,
PromptMessage,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -147,17 +147,14 @@ class Server:
) )
case types.EmbeddedResource() as resource: case types.EmbeddedResource() as resource:
content = EmbeddedResource( content = EmbeddedResource(
type="resource", type="resource", resource=resource.resource
resource=resource.resource
) )
case _: case _:
raise ValueError( raise ValueError(
f"Unexpected content type: {type(message.content)}" f"Unexpected content type: {type(message.content)}"
) )
prompt_message = PromptMessage( prompt_message = PromptMessage(role=message.role, content=content)
role=message.role, content=content
)
messages.append(prompt_message) messages.append(prompt_message)
return ServerResult( return ServerResult(
@@ -175,9 +172,7 @@ class Server:
async def handler(_: Any): async def handler(_: Any):
resources = await func() resources = await func()
return ServerResult( return ServerResult(ListResourcesResult(resources=resources))
ListResourcesResult(resources=resources)
)
self.request_handlers[ListResourcesRequest] = handler self.request_handlers[ListResourcesRequest] = handler
return func return func
@@ -222,7 +217,6 @@ class Server:
return decorator return decorator
def set_logging_level(self): def set_logging_level(self):
from mcp_python.types import EmptyResult from mcp_python.types import EmptyResult
@@ -282,10 +276,17 @@ class Server:
return decorator return decorator
def call_tool(self): def call_tool(self):
from mcp_python.types import CallToolResult, TextContent, ImageContent, EmbeddedResource from mcp_python.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
)
def decorator( def decorator(
func: Callable[..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]] func: Callable[
..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]
],
): ):
logger.debug("Registering handler for CallToolRequest") logger.debug("Registering handler for CallToolRequest")
@@ -298,28 +299,26 @@ class Server:
case str() as text: case str() as text:
content.append(TextContent(type="text", text=text)) content.append(TextContent(type="text", text=text))
case types.ImageContent() as img: case types.ImageContent() as img:
content.append(ImageContent( content.append(
type="image", ImageContent(
data=img.data, type="image",
mimeType=img.mime_type data=img.data,
)) mimeType=img.mime_type,
)
)
case types.EmbeddedResource() as resource: case types.EmbeddedResource() as resource:
content.append(EmbeddedResource( content.append(
type="resource", EmbeddedResource(
resource=resource.resource type="resource", resource=resource.resource
)) )
)
return ServerResult( return ServerResult(CallToolResult(content=content, isError=False))
CallToolResult(
content=content,
isError=False
)
)
except Exception as e: except Exception as e:
return ServerResult( return ServerResult(
CallToolResult( CallToolResult(
content=[TextContent(type="text", text=str(e))], content=[TextContent(type="text", text=str(e))],
isError=True isError=True,
) )
) )

View File

@@ -12,7 +12,7 @@ from mcp_python.shared.session import (
RequestResponder, RequestResponder,
) )
from mcp_python.types import ( from mcp_python.types import (
ListRootsResult, LATEST_PROTOCOL_VERSION, LATEST_PROTOCOL_VERSION,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
CreateMessageResult, CreateMessageResult,
@@ -23,15 +23,16 @@ from mcp_python.types import (
InitializeRequest, InitializeRequest,
InitializeResult, InitializeResult,
JSONRPCMessage, JSONRPCMessage,
ListRootsResult,
LoggingLevel, LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage, SamplingMessage,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
ServerResult, ServerResult,
ResourceListChangedNotification,
ToolListChangedNotification, ToolListChangedNotification,
PromptListChangedNotification,
ModelPreferences,
) )

View File

@@ -1,5 +1,6 @@
""" """
This module provides simpler types to use with the server for managing prompts and tools. This module provides simpler types to use with the server for managing prompts
and tools.
""" """
from dataclasses import dataclass from dataclasses import dataclass
@@ -7,7 +8,12 @@ from typing import Literal
from pydantic import BaseModel from pydantic import BaseModel
from mcp_python.types import Role, ServerCapabilities, TextResourceContents, BlobResourceContents from mcp_python.types import (
BlobResourceContents,
Role,
ServerCapabilities,
TextResourceContents,
)
@dataclass @dataclass

View File

@@ -15,14 +15,14 @@ from mcp_python.types import JSONRPCMessage
MessageStream = tuple[ MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage] MemoryObjectSendStream[JSONRPCMessage],
] ]
@asynccontextmanager @asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[ async def create_client_server_memory_streams() -> (
tuple[MessageStream, MessageStream], AsyncGenerator[tuple[MessageStream, MessageStream], None]
None ):
]:
""" """
Creates a pair of bidirectional memory streams for client-server communication. Creates a pair of bidirectional memory streams for client-server communication.

View File

@@ -154,7 +154,8 @@ class BaseSession(
try: try:
with anyio.fail_after( with anyio.fail_after(
None if self._read_timeout_seconds is None None
if self._read_timeout_seconds is None
else self._read_timeout_seconds.total_seconds() else self._read_timeout_seconds.total_seconds()
): ):
response_or_error = await response_stream_reader.receive() response_or_error = await response_stream_reader.receive()
@@ -168,7 +169,6 @@ class BaseSession(
f"{self._read_timeout_seconds} seconds." f"{self._read_timeout_seconds} seconds."
), ),
) )
) )
if isinstance(response_or_error, JSONRPCError): if isinstance(response_or_error, JSONRPCError):

View File

@@ -654,7 +654,9 @@ class ToolListChangedNotification(Notification):
params: NotificationParams | None = None params: NotificationParams | None = None
LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] LoggingLevel = Literal[
"debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"
]
class SetLevelRequestParams(RequestParams): class SetLevelRequestParams(RequestParams):
@@ -708,7 +710,8 @@ class ModelHint(BaseModel):
class ModelPreferences(BaseModel): class ModelPreferences(BaseModel):
""" """
The server's preferences for model selection, requested of the client during sampling. The server's preferences for model selection, requested of the client during
sampling.
Because LLMs can vary along multiple dimensions, choosing the "best" model is Because LLMs can vary along multiple dimensions, choosing the "best" model is
rarely straightforward. Different models excel in different areas—some are rarely straightforward. Different models excel in different areas—some are
@@ -761,7 +764,10 @@ class CreateMessageRequestParams(RequestParams):
messages: list[SamplingMessage] messages: list[SamplingMessage]
modelPreferences: ModelPreferences | None = None modelPreferences: ModelPreferences | None = None
"""The server's preferences for which model to select. The client MAY ignore these preferences.""" """
The server's preferences for which model to select. The client MAY ignore
these preferences.
"""
systemPrompt: str | None = None systemPrompt: str | None = None
"""An optional system prompt the server wants to use for sampling.""" """An optional system prompt the server wants to use for sampling."""
includeContext: IncludeContext | None = None includeContext: IncludeContext | None = None
@@ -911,9 +917,12 @@ class ListRootsResult(Result):
class RootsListChangedNotification(Notification): class RootsListChangedNotification(Notification):
""" """
A notification from the client to the server, informing it that the list of roots has changed. A notification from the client to the server, informing it that the list of
This notification should be sent whenever the client adds, removes, or modifies any root. roots has changed.
The server should then request an updated list of roots using the ListRootsRequest.
This notification should be sent whenever the client adds, removes, or
modifies any root. The server should then request an updated list of roots
using the ListRootsRequest.
""" """
method: Literal["notifications/roots/list_changed"] method: Literal["notifications/roots/list_changed"]
@@ -940,7 +949,11 @@ class ClientRequest(
pass pass
class ClientNotification(RootModel[ProgressNotification | InitializedNotification | RootsListChangedNotification]): class ClientNotification(
RootModel[
ProgressNotification | InitializedNotification | RootsListChangedNotification
]
):
pass pass

View File

@@ -11,6 +11,7 @@ TEST_INITIALIZATION_OPTIONS = InitializationOptions(
capabilities=ServerCapabilities(), capabilities=ServerCapabilities(),
) )
@pytest.fixture @pytest.fixture
def mcp_server() -> Server: def mcp_server() -> Server:
server = Server(name="test_server") server = Server(name="test_server")
@@ -21,7 +22,7 @@ def mcp_server() -> Server:
Resource( Resource(
uri=AnyUrl("memory://test"), uri=AnyUrl("memory://test"),
name="Test Resource", name="Test Resource",
description="A test resource" description="A test resource",
) )
] ]