Formatting

Justin Spahr-Summers
2024-11-06 12:35:32 +00:00
parent a891ad4689
commit c7d8f11e0c
8 changed files with 87 additions and 58 deletions

View File

@@ -1,7 +1,7 @@
from datetime import timedelta
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, FileUrl
from pydantic import AnyUrl
from mcp_python.shared.session import BaseSession
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -21,7 +21,6 @@ from mcp_python.types import (
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListRootsResult,
ListToolsResult,
LoggingLevel,
PromptReference,
@@ -71,9 +70,11 @@ class ClientSession(
sampling=None,
experimental=None,
roots={
# TODO: Should this be based on whether we _will_ send notifications, or only whether they're supported?
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
"listChanged": True
}
},
),
clientInfo=Implementation(name="mcp_python", version="0.1.0"),
),
@@ -246,7 +247,9 @@ class ClientSession(
ListPromptsResult,
)
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> GetPromptResult:
async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
"""Send a prompts/get request."""
from mcp_python.types import GetPromptRequest, GetPromptRequestParams
@@ -260,9 +263,15 @@ class ClientSession(
GetPromptResult,
)
async def complete(self, ref: ResourceReference | PromptReference, argument: dict) -> CompleteResult:
async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
"""Send a completion/complete request."""
from mcp_python.types import CompleteRequest, CompleteRequestParams, CompletionArgument
from mcp_python.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)
return await self.send_request(
ClientRequest(

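For orientation, the two client methods whose signatures are wrapped above, get_prompt() and complete(), could be exercised roughly as follows. This is a sketch only: the import path, and the PromptReference fields and "name"/"value" keys of the argument dict (taken from the MCP spec, not this diff), are assumptions.

from mcp_python.client.session import ClientSession  # assumed import path
from mcp_python.types import PromptReference


async def demo_prompts(session: ClientSession) -> None:
    # prompts/get with optional string-valued arguments (signature from the diff)
    prompt = await session.get_prompt("greet", arguments={"name": "Ada"})

    # completion/complete against a prompt argument; the "name"/"value" keys
    # follow the MCP CompletionArgument shape and are an assumption here
    completion = await session.complete(
        PromptReference(type="ref/prompt", name="greet"),
        argument={"name": "name", "value": "A"},
    )
    print(prompt, completion)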
View File

@@ -18,6 +18,7 @@ from mcp_python.types import (
ClientNotification,
ClientRequest,
CompleteRequest,
EmbeddedResource,
EmptyResult,
ErrorData,
JSONRPCMessage,
@@ -31,6 +32,7 @@ from mcp_python.types import (
PingRequest,
ProgressNotification,
Prompt,
PromptMessage,
PromptReference,
ReadResourceRequest,
ReadResourceResult,
@@ -40,11 +42,9 @@ from mcp_python.types import (
ServerResult,
SetLevelRequest,
SubscribeRequest,
TextContent,
Tool,
UnsubscribeRequest,
TextContent,
EmbeddedResource,
PromptMessage,
)
logger = logging.getLogger(__name__)
@@ -147,17 +147,14 @@ class Server:
)
case types.EmbeddedResource() as resource:
content = EmbeddedResource(
type="resource",
resource=resource.resource
type="resource", resource=resource.resource
)
case _:
raise ValueError(
f"Unexpected content type: {type(message.content)}"
)
prompt_message = PromptMessage(
role=message.role, content=content
)
prompt_message = PromptMessage(role=message.role, content=content)
messages.append(prompt_message)
return ServerResult(
@@ -175,9 +172,7 @@ class Server:
async def handler(_: Any):
resources = await func()
return ServerResult(
ListResourcesResult(resources=resources)
)
return ServerResult(ListResourcesResult(resources=resources))
self.request_handlers[ListResourcesRequest] = handler
return func
@@ -222,7 +217,6 @@ class Server:
return decorator
def set_logging_level(self):
from mcp_python.types import EmptyResult
@@ -282,10 +276,17 @@ class Server:
return decorator
def call_tool(self):
from mcp_python.types import CallToolResult, TextContent, ImageContent, EmbeddedResource
from mcp_python.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
)
def decorator(
func: Callable[..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]]
func: Callable[
..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]
],
):
logger.debug("Registering handler for CallToolRequest")
@@ -298,28 +299,26 @@ class Server:
case str() as text:
content.append(TextContent(type="text", text=text))
case types.ImageContent() as img:
content.append(ImageContent(
type="image",
data=img.data,
mimeType=img.mime_type
))
content.append(
ImageContent(
type="image",
data=img.data,
mimeType=img.mime_type,
)
)
case types.EmbeddedResource() as resource:
content.append(EmbeddedResource(
type="resource",
resource=resource.resource
))
content.append(
EmbeddedResource(
type="resource", resource=resource.resource
)
)
return ServerResult(
CallToolResult(
content=content,
isError=False
)
)
return ServerResult(CallToolResult(content=content, isError=False))
except Exception as e:
return ServerResult(
CallToolResult(
content=[TextContent(type="text", text=str(e))],
isError=True
isError=True,
)
)

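Since this file's hunks touch the call_tool() registration path, a short sketch of how a handler plugs into it may help. The handler parameters (tool name and arguments) and the import path are assumptions; the no-argument decorator factory and the conversion of plain strings into TextContent come from the diff above.

from mcp_python.server import Server  # assumed import path

server = Server(name="example_server")


@server.call_tool()  # decorator factory above takes no arguments
async def handle_tool(name: str, arguments: dict):
    # Plain strings returned here are wrapped into TextContent(type="text", ...)
    # by the match statement in the handler wrapper shown above.
    return [f"{name} called with {arguments!r}"]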
View File

@@ -12,7 +12,7 @@ from mcp_python.shared.session import (
RequestResponder,
)
from mcp_python.types import (
ListRootsResult, LATEST_PROTOCOL_VERSION,
LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
CreateMessageResult,
@@ -23,15 +23,16 @@ from mcp_python.types import (
InitializeRequest,
InitializeResult,
JSONRPCMessage,
ListRootsResult,
LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage,
ServerNotification,
ServerRequest,
ServerResult,
ResourceListChangedNotification,
ToolListChangedNotification,
PromptListChangedNotification,
ModelPreferences,
)

View File

@@ -1,5 +1,6 @@
"""
This module provides simpler types to use with the server for managing prompts and tools.
This module provides simpler types to use with the server for managing prompts
and tools.
"""
from dataclasses import dataclass
@@ -7,7 +8,12 @@ from typing import Literal
from pydantic import BaseModel
from mcp_python.types import Role, ServerCapabilities, TextResourceContents, BlobResourceContents
from mcp_python.types import (
BlobResourceContents,
Role,
ServerCapabilities,
TextResourceContents,
)
@dataclass

View File

@@ -15,14 +15,14 @@ from mcp_python.types import JSONRPCMessage
MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage]
MemoryObjectSendStream[JSONRPCMessage],
]
@asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[
tuple[MessageStream, MessageStream],
None
]:
async def create_client_server_memory_streams() -> (
AsyncGenerator[tuple[MessageStream, MessageStream], None]
):
"""
Creates a pair of bidirectional memory streams for client-server communication.

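A consumer of the reformatted context manager above would unpack the two MessageStream pairs like this. The import path is an assumption; the tuple shapes follow the MessageStream alias in the diff.

from mcp_python.shared.memory import (  # assumed import path
    create_client_server_memory_streams,
)


async def demo_streams() -> None:
    async with create_client_server_memory_streams() as (client_streams, server_streams):
        # Each side is a (receive, send) pair per the MessageStream alias above.
        client_read, client_write = client_streams
        server_read, server_write = server_streams
        # Presumably messages sent on one side's send stream arrive on the
        # other side's receive stream, making the pair bidirectional.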
View File

@@ -154,7 +154,8 @@ class BaseSession(
try:
with anyio.fail_after(
None if self._read_timeout_seconds is None
None
if self._read_timeout_seconds is None
else self._read_timeout_seconds.total_seconds()
):
response_or_error = await response_stream_reader.receive()
@@ -168,7 +169,6 @@ class BaseSession(
f"{self._read_timeout_seconds} seconds."
),
)
)
if isinstance(response_or_error, JSONRPCError):

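The hunk above only rewraps the conditional passed to anyio.fail_after; in isolation the pattern looks like this (a generic sketch, not the library's own helper).

from datetime import timedelta

import anyio


async def receive_with_timeout(receive, read_timeout_seconds: timedelta | None):
    # fail_after(None) disables the deadline; otherwise the timedelta is
    # converted to a float number of seconds.
    with anyio.fail_after(
        None
        if read_timeout_seconds is None
        else read_timeout_seconds.total_seconds()
    ):
        return await receive()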
View File

@@ -654,7 +654,9 @@ class ToolListChangedNotification(Notification):
params: NotificationParams | None = None
LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"]
LoggingLevel = Literal[
"debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"
]
class SetLevelRequestParams(RequestParams):
@@ -708,7 +710,8 @@ class ModelHint(BaseModel):
class ModelPreferences(BaseModel):
"""
The server's preferences for model selection, requested of the client during sampling.
The server's preferences for model selection, requested of the client during
sampling.
Because LLMs can vary along multiple dimensions, choosing the "best" model is
rarely straightforward. Different models excel in different areas—some are
@@ -761,7 +764,10 @@ class CreateMessageRequestParams(RequestParams):
messages: list[SamplingMessage]
modelPreferences: ModelPreferences | None = None
"""The server's preferences for which model to select. The client MAY ignore these preferences."""
"""
The server's preferences for which model to select. The client MAY ignore
these preferences.
"""
systemPrompt: str | None = None
"""An optional system prompt the server wants to use for sampling."""
includeContext: IncludeContext | None = None
@@ -911,9 +917,12 @@ class ListRootsResult(Result):
class RootsListChangedNotification(Notification):
"""
A notification from the client to the server, informing it that the list of roots has changed.
This notification should be sent whenever the client adds, removes, or modifies any root.
The server should then request an updated list of roots using the ListRootsRequest.
A notification from the client to the server, informing it that the list of
roots has changed.
This notification should be sent whenever the client adds, removes, or
modifies any root. The server should then request an updated list of roots
using the ListRootsRequest.
"""
method: Literal["notifications/roots/list_changed"]
@@ -940,7 +949,11 @@ class ClientRequest(
pass
class ClientNotification(RootModel[ProgressNotification | InitializedNotification | RootsListChangedNotification]):
class ClientNotification(
RootModel[
ProgressNotification | InitializedNotification | RootsListChangedNotification
]
):
pass

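The rewrapped ClientNotification is a pydantic RootModel over a union, so a raw notification payload validates straight into it. A small sketch, assuming params stays optional on RootsListChangedNotification as it is on ToolListChangedNotification above:

from mcp_python.types import ClientNotification  # names from the diff

raw = {"method": "notifications/roots/list_changed"}
notification = ClientNotification.model_validate(raw)
print(type(notification.root).__name__)  # expected: RootsListChangedNotification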
View File

@@ -11,6 +11,7 @@ TEST_INITIALIZATION_OPTIONS = InitializationOptions(
capabilities=ServerCapabilities(),
)
@pytest.fixture
def mcp_server() -> Server:
server = Server(name="test_server")
@@ -21,7 +22,7 @@ def mcp_server() -> Server:
Resource(
uri=AnyUrl("memory://test"),
name="Test Resource",
description="A test resource"
description="A test resource",
)
]
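
For completeness, the truncated fixture above would plausibly finish along these lines. The list_resources() decorator name is a guess based on the handler-registration pattern in the Server diff; only Server(name=...), the Resource fields, and the names already imported in this test module are taken from the hunks shown.

@pytest.fixture
def mcp_server() -> Server:
    server = Server(name="test_server")

    @server.list_resources()  # hypothetical decorator name, not shown in the hunks
    async def list_resources() -> list[Resource]:
        return [
            Resource(
                uri=AnyUrl("memory://test"),
                name="Test Resource",
                description="A test resource",
            )
        ]

    return server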