diff --git a/mcp_python/client/session.py b/mcp_python/client/session.py index 663109f..6c6d01f 100644 --- a/mcp_python/client/session.py +++ b/mcp_python/client/session.py @@ -1,7 +1,7 @@ from datetime import timedelta from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl, FileUrl +from pydantic import AnyUrl from mcp_python.shared.session import BaseSession from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -21,7 +21,6 @@ from mcp_python.types import ( JSONRPCMessage, ListPromptsResult, ListResourcesResult, - ListRootsResult, ListToolsResult, LoggingLevel, PromptReference, @@ -71,9 +70,11 @@ class ClientSession( sampling=None, experimental=None, roots={ - # TODO: Should this be based on whether we _will_ send notifications, or only whether they're supported? + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? "listChanged": True - } + }, ), clientInfo=Implementation(name="mcp_python", version="0.1.0"), ), @@ -246,7 +247,9 @@ class ClientSession( ListPromptsResult, ) - async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> GetPromptResult: + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> GetPromptResult: """Send a prompts/get request.""" from mcp_python.types import GetPromptRequest, GetPromptRequestParams @@ -260,9 +263,15 @@ class ClientSession( GetPromptResult, ) - async def complete(self, ref: ResourceReference | PromptReference, argument: dict) -> CompleteResult: + async def complete( + self, ref: ResourceReference | PromptReference, argument: dict + ) -> CompleteResult: """Send a completion/complete request.""" - from mcp_python.types import CompleteRequest, CompleteRequestParams, CompletionArgument + from mcp_python.types import ( + CompleteRequest, + CompleteRequestParams, + CompletionArgument, + ) return await self.send_request( ClientRequest( diff --git a/mcp_python/server/__init__.py b/mcp_python/server/__init__.py index 5c85eea..26d6902 100644 --- a/mcp_python/server/__init__.py +++ b/mcp_python/server/__init__.py @@ -18,6 +18,7 @@ from mcp_python.types import ( ClientNotification, ClientRequest, CompleteRequest, + EmbeddedResource, EmptyResult, ErrorData, JSONRPCMessage, @@ -31,6 +32,7 @@ from mcp_python.types import ( PingRequest, ProgressNotification, Prompt, + PromptMessage, PromptReference, ReadResourceRequest, ReadResourceResult, @@ -40,11 +42,9 @@ from mcp_python.types import ( ServerResult, SetLevelRequest, SubscribeRequest, + TextContent, Tool, UnsubscribeRequest, - TextContent, - EmbeddedResource, - PromptMessage, ) logger = logging.getLogger(__name__) @@ -147,17 +147,14 @@ class Server: ) case types.EmbeddedResource() as resource: content = EmbeddedResource( - type="resource", - resource=resource.resource + type="resource", resource=resource.resource ) case _: raise ValueError( f"Unexpected content type: {type(message.content)}" ) - prompt_message = PromptMessage( - role=message.role, content=content - ) + prompt_message = PromptMessage(role=message.role, content=content) messages.append(prompt_message) return ServerResult( @@ -175,9 +172,7 @@ class Server: async def handler(_: Any): resources = await func() - return ServerResult( - ListResourcesResult(resources=resources) - ) + return ServerResult(ListResourcesResult(resources=resources)) self.request_handlers[ListResourcesRequest] = handler return func @@ -222,7 +217,6 @@ class Server: return decorator - def set_logging_level(self): from mcp_python.types import EmptyResult @@ -282,10 +276,17 @@ class Server: return decorator def call_tool(self): - from mcp_python.types import CallToolResult, TextContent, ImageContent, EmbeddedResource + from mcp_python.types import ( + CallToolResult, + EmbeddedResource, + ImageContent, + TextContent, + ) def decorator( - func: Callable[..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]] + func: Callable[ + ..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]] + ], ): logger.debug("Registering handler for CallToolRequest") @@ -298,28 +299,26 @@ class Server: case str() as text: content.append(TextContent(type="text", text=text)) case types.ImageContent() as img: - content.append(ImageContent( - type="image", - data=img.data, - mimeType=img.mime_type - )) + content.append( + ImageContent( + type="image", + data=img.data, + mimeType=img.mime_type, + ) + ) case types.EmbeddedResource() as resource: - content.append(EmbeddedResource( - type="resource", - resource=resource.resource - )) + content.append( + EmbeddedResource( + type="resource", resource=resource.resource + ) + ) - return ServerResult( - CallToolResult( - content=content, - isError=False - ) - ) + return ServerResult(CallToolResult(content=content, isError=False)) except Exception as e: return ServerResult( CallToolResult( content=[TextContent(type="text", text=str(e))], - isError=True + isError=True, ) ) diff --git a/mcp_python/server/session.py b/mcp_python/server/session.py index 7ecdf19..db7ebc2 100644 --- a/mcp_python/server/session.py +++ b/mcp_python/server/session.py @@ -12,7 +12,7 @@ from mcp_python.shared.session import ( RequestResponder, ) from mcp_python.types import ( - ListRootsResult, LATEST_PROTOCOL_VERSION, + LATEST_PROTOCOL_VERSION, ClientNotification, ClientRequest, CreateMessageResult, @@ -23,15 +23,16 @@ from mcp_python.types import ( InitializeRequest, InitializeResult, JSONRPCMessage, + ListRootsResult, LoggingLevel, + ModelPreferences, + PromptListChangedNotification, + ResourceListChangedNotification, SamplingMessage, ServerNotification, ServerRequest, ServerResult, - ResourceListChangedNotification, ToolListChangedNotification, - PromptListChangedNotification, - ModelPreferences, ) diff --git a/mcp_python/server/types.py b/mcp_python/server/types.py index 437bc29..acc5c1e 100644 --- a/mcp_python/server/types.py +++ b/mcp_python/server/types.py @@ -1,5 +1,6 @@ """ -This module provides simpler types to use with the server for managing prompts and tools. +This module provides simpler types to use with the server for managing prompts +and tools. """ from dataclasses import dataclass @@ -7,7 +8,12 @@ from typing import Literal from pydantic import BaseModel -from mcp_python.types import Role, ServerCapabilities, TextResourceContents, BlobResourceContents +from mcp_python.types import ( + BlobResourceContents, + Role, + ServerCapabilities, + TextResourceContents, +) @dataclass diff --git a/mcp_python/shared/memory.py b/mcp_python/shared/memory.py index a291749..6ebfe9f 100644 --- a/mcp_python/shared/memory.py +++ b/mcp_python/shared/memory.py @@ -15,14 +15,14 @@ from mcp_python.types import JSONRPCMessage MessageStream = tuple[ MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage] + MemoryObjectSendStream[JSONRPCMessage], ] + @asynccontextmanager -async def create_client_server_memory_streams() -> AsyncGenerator[ - tuple[MessageStream, MessageStream], - None -]: +async def create_client_server_memory_streams() -> ( + AsyncGenerator[tuple[MessageStream, MessageStream], None] +): """ Creates a pair of bidirectional memory streams for client-server communication. diff --git a/mcp_python/shared/session.py b/mcp_python/shared/session.py index f063a33..95e354b 100644 --- a/mcp_python/shared/session.py +++ b/mcp_python/shared/session.py @@ -154,7 +154,8 @@ class BaseSession( try: with anyio.fail_after( - None if self._read_timeout_seconds is None + None + if self._read_timeout_seconds is None else self._read_timeout_seconds.total_seconds() ): response_or_error = await response_stream_reader.receive() @@ -168,7 +169,6 @@ class BaseSession( f"{self._read_timeout_seconds} seconds." ), ) - ) if isinstance(response_or_error, JSONRPCError): diff --git a/mcp_python/types.py b/mcp_python/types.py index 0e26e74..4e80719 100644 --- a/mcp_python/types.py +++ b/mcp_python/types.py @@ -654,7 +654,9 @@ class ToolListChangedNotification(Notification): params: NotificationParams | None = None -LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] +LoggingLevel = Literal[ + "debug", "info", "notice", "warning", "error", "critical", "alert", "emergency" +] class SetLevelRequestParams(RequestParams): @@ -708,7 +710,8 @@ class ModelHint(BaseModel): class ModelPreferences(BaseModel): """ - The server's preferences for model selection, requested of the client during sampling. + The server's preferences for model selection, requested of the client during + sampling. Because LLMs can vary along multiple dimensions, choosing the "best" model is rarely straightforward. Different models excel in different areas—some are @@ -761,7 +764,10 @@ class CreateMessageRequestParams(RequestParams): messages: list[SamplingMessage] modelPreferences: ModelPreferences | None = None - """The server's preferences for which model to select. The client MAY ignore these preferences.""" + """ + The server's preferences for which model to select. The client MAY ignore + these preferences. + """ systemPrompt: str | None = None """An optional system prompt the server wants to use for sampling.""" includeContext: IncludeContext | None = None @@ -911,9 +917,12 @@ class ListRootsResult(Result): class RootsListChangedNotification(Notification): """ - A notification from the client to the server, informing it that the list of roots has changed. - This notification should be sent whenever the client adds, removes, or modifies any root. - The server should then request an updated list of roots using the ListRootsRequest. + A notification from the client to the server, informing it that the list of + roots has changed. + + This notification should be sent whenever the client adds, removes, or + modifies any root. The server should then request an updated list of roots + using the ListRootsRequest. """ method: Literal["notifications/roots/list_changed"] @@ -940,7 +949,11 @@ class ClientRequest( pass -class ClientNotification(RootModel[ProgressNotification | InitializedNotification | RootsListChangedNotification]): +class ClientNotification( + RootModel[ + ProgressNotification | InitializedNotification | RootsListChangedNotification + ] +): pass diff --git a/tests/conftest.py b/tests/conftest.py index 37ff5a4..28690b2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ TEST_INITIALIZATION_OPTIONS = InitializationOptions( capabilities=ServerCapabilities(), ) + @pytest.fixture def mcp_server() -> Server: server = Server(name="test_server") @@ -21,7 +22,7 @@ def mcp_server() -> Server: Resource( uri=AnyUrl("memory://test"), name="Test Resource", - description="A test resource" + description="A test resource", ) ]