From fd68df6687c482c5407bc38725712a228faca9b5 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Fri, 11 Oct 2024 11:54:16 +0100 Subject: [PATCH] Format with ruff --- mcp_python/__init__.py | 4 +- mcp_python/client/session.py | 3 +- mcp_python/client/sse.py | 24 ++++-- mcp_python/client/stdio.py | 3 +- mcp_python/server/__init__.py | 57 +++++++------ mcp_python/server/__main__.py | 18 +++- mcp_python/server/session.py | 7 +- mcp_python/server/sse.py | 17 ++-- mcp_python/server/stdio.py | 10 ++- mcp_python/server/types.py | 1 + mcp_python/server/websocket.py | 7 +- mcp_python/shared/session.py | 53 ++++++++---- mcp_python/types.py | 149 +++++++++++++++++++++++++-------- tests/client/test_session.py | 8 +- tests/server/test_session.py | 8 +- 15 files changed, 268 insertions(+), 101 deletions(-) diff --git a/mcp_python/__init__.py b/mcp_python/__init__.py index 2285847..78a62be 100644 --- a/mcp_python/__init__.py +++ b/mcp_python/__init__.py @@ -37,7 +37,6 @@ from .types import ( ReadResourceResult, Resource, ResourceUpdatedNotification, - Role as SamplingRole, SamplingMessage, ServerCapabilities, ServerNotification, @@ -49,6 +48,9 @@ from .types import ( Tool, UnsubscribeRequest, ) +from .types import ( + Role as SamplingRole, +) __all__ = [ "CallToolRequest", diff --git a/mcp_python/client/session.py b/mcp_python/client/session.py index 5eab70e..769e945 100644 --- a/mcp_python/client/session.py +++ b/mcp_python/client/session.py @@ -62,7 +62,8 @@ class ClientSession( if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION: raise RuntimeError( - f"Unsupported protocol version from the server: {result.protocolVersion}" + "Unsupported protocol version from the server: " + f"{result.protocolVersion}" ) await self.send_notification( diff --git a/mcp_python/client/sse.py b/mcp_python/client/sse.py index ebb4842..09826a6 100644 --- a/mcp_python/client/sse.py +++ b/mcp_python/client/sse.py @@ -19,11 +19,17 @@ def remove_request_params(url: str) -> str: @asynccontextmanager -async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5): +async def sse_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: float = 5, + sse_read_timeout: float = 60 * 5, +): """ Client transport for SSE. - `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. """ read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] @@ -67,7 +73,10 @@ async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: f or url_parsed.scheme != endpoint_parsed.scheme ): - error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}" + error_msg = ( + "Endpoint origin does not match " + f"connection origin: {endpoint_url}" + ) logger.error(error_msg) raise ValueError(error_msg) @@ -104,11 +113,16 @@ async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: f logger.debug(f"Sending client message: {message}") response = await client.post( endpoint_url, - json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + json=message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), ) response.raise_for_status() logger.debug( - f"Client message sent successfully: {response.status_code}" + "Client message sent successfully: " + f"{response.status_code}" ) except Exception as exc: logger.error(f"Error in post_writer: {exc}") diff --git a/mcp_python/client/stdio.py b/mcp_python/client/stdio.py index 30c0bf6..f9404d3 100644 --- a/mcp_python/client/stdio.py +++ b/mcp_python/client/stdio.py @@ -28,7 +28,8 @@ class StdioServerParameters(BaseModel): @asynccontextmanager async def stdio_client(server: StdioServerParameters): """ - Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. + Client transport for stdio: this will connect to a server by spawning a + process and communicating with it over stdin/stdout. """ read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] diff --git a/mcp_python/server/__init__.py b/mcp_python/server/__init__.py index 7801f64..3339db2 100644 --- a/mcp_python/server/__init__.py +++ b/mcp_python/server/__init__.py @@ -55,9 +55,11 @@ class Server: def create_initialization_options(self) -> types.InitializationOptions: """Create initialization options from this server instance.""" + def pkg_version(package: str) -> str: try: from importlib.metadata import version + return version(package) except Exception: return "unknown" @@ -69,16 +71,17 @@ class Server: ) def get_capabilities(self) -> ServerCapabilities: - """Convert existing handlers to a ServerCapabilities object.""" - def get_capability(req_type: type) -> dict[str, Any] | None: - return {} if req_type in self.request_handlers else None + """Convert existing handlers to a ServerCapabilities object.""" - return ServerCapabilities( - prompts=get_capability(ListPromptsRequest), - resources=get_capability(ListResourcesRequest), - tools=get_capability(ListPromptsRequest), - logging=get_capability(SetLevelRequest) - ) + def get_capability(req_type: type) -> dict[str, Any] | None: + return {} if req_type in self.request_handlers else None + + return ServerCapabilities( + prompts=get_capability(ListPromptsRequest), + resources=get_capability(ListResourcesRequest), + tools=get_capability(ListPromptsRequest), + logging=get_capability(SetLevelRequest), + ) @property def request_context(self) -> RequestContext: @@ -87,7 +90,7 @@ class Server: def list_prompts(self): def decorator(func: Callable[[], Awaitable[list[Prompt]]]): - logger.debug(f"Registering handler for PromptListRequest") + logger.debug("Registering handler for PromptListRequest") async def handler(_: Any): prompts = await func() @@ -103,17 +106,19 @@ class Server: GetPromptRequest, GetPromptResult, ImageContent, - Role as Role, SamplingMessage, TextContent, ) + from mcp_python.types import ( + Role as Role, + ) def decorator( func: Callable[ [str, dict[str, str] | None], Awaitable[types.PromptResponse] ], ): - logger.debug(f"Registering handler for GetPromptRequest") + logger.debug("Registering handler for GetPromptRequest") async def handler(req: GetPromptRequest): prompt_get = await func(req.params.name, req.params.arguments) @@ -149,7 +154,7 @@ class Server: def list_resources(self): def decorator(func: Callable[[], Awaitable[list[Resource]]]): - logger.debug(f"Registering handler for ListResourcesRequest") + logger.debug("Registering handler for ListResourcesRequest") async def handler(_: Any): resources = await func() @@ -169,7 +174,7 @@ class Server: ) def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]): - logger.debug(f"Registering handler for ReadResourceRequest") + logger.debug("Registering handler for ReadResourceRequest") async def handler(req: ReadResourceRequest): result = await func(req.params.uri) @@ -204,7 +209,7 @@ class Server: from mcp_python.types import EmptyResult def decorator(func: Callable[[LoggingLevel], Awaitable[None]]): - logger.debug(f"Registering handler for SetLevelRequest") + logger.debug("Registering handler for SetLevelRequest") async def handler(req: SetLevelRequest): await func(req.params.level) @@ -219,7 +224,7 @@ class Server: from mcp_python.types import EmptyResult def decorator(func: Callable[[AnyUrl], Awaitable[None]]): - logger.debug(f"Registering handler for SubscribeRequest") + logger.debug("Registering handler for SubscribeRequest") async def handler(req: SubscribeRequest): await func(req.params.uri) @@ -234,7 +239,7 @@ class Server: from mcp_python.types import EmptyResult def decorator(func: Callable[[AnyUrl], Awaitable[None]]): - logger.debug(f"Registering handler for UnsubscribeRequest") + logger.debug("Registering handler for UnsubscribeRequest") async def handler(req: UnsubscribeRequest): await func(req.params.uri) @@ -249,7 +254,7 @@ class Server: from mcp_python.types import CallToolResult def decorator(func: Callable[..., Awaitable[Any]]): - logger.debug(f"Registering handler for CallToolRequest") + logger.debug("Registering handler for CallToolRequest") async def handler(req: CallToolRequest): result = await func(req.params.name, **(req.params.arguments or {})) @@ -264,7 +269,7 @@ class Server: def decorator( func: Callable[[str | int, float, float | None], Awaitable[None]], ): - logger.debug(f"Registering handler for ProgressNotification") + logger.debug("Registering handler for ProgressNotification") async def handler(req: ProgressNotification): await func( @@ -286,7 +291,7 @@ class Server: Awaitable[Completion | None], ], ): - logger.debug(f"Registering handler for CompleteRequest") + logger.debug("Registering handler for CompleteRequest") async def handler(req: CompleteRequest): completion = await func(req.params.ref, req.params.argument) @@ -307,10 +312,12 @@ class Server: self, read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], - initialization_options: types.InitializationOptions + initialization_options: types.InitializationOptions, ): with warnings.catch_warnings(record=True) as w: - async with ServerSession(read_stream, write_stream, initialization_options) as session: + async with ServerSession( + read_stream, write_stream, initialization_options + ) as session: async for message in session.incoming_messages: logger.debug(f"Received message: {message}") @@ -359,14 +366,16 @@ class Server: handler = self.notification_handlers[type(notify)] logger.debug( - f"Dispatching notification of type {type(notify).__name__}" + f"Dispatching notification of type " + f"{type(notify).__name__}" ) try: await handler(notify) except Exception as err: logger.error( - f"Uncaught exception in notification handler: {err}" + f"Uncaught exception in notification handler: " + f"{err}" ) for warning in w: diff --git a/mcp_python/server/__main__.py b/mcp_python/server/__main__.py index efb7dd8..6cb8822 100644 --- a/mcp_python/server/__main__.py +++ b/mcp_python/server/__main__.py @@ -1,11 +1,12 @@ +import importlib.metadata import logging import sys -import importlib.metadata + import anyio from mcp_python.server.session import ServerSession -from mcp_python.server.types import InitializationOptions from mcp_python.server.stdio import stdio_server +from mcp_python.server.types import InitializationOptions from mcp_python.types import ServerCapabilities if not sys.warnoptions: @@ -30,7 +31,18 @@ async def receive_loop(session: ServerSession): async def main(): version = importlib.metadata.version("mcp_python") async with stdio_server() as (read_stream, write_stream): - async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream: + async with ( + ServerSession( + read_stream, + write_stream, + InitializationOptions( + server_name="mcp_python", + server_version=version, + capabilities=ServerCapabilities(), + ), + ) as session, + write_stream, + ): await receive_loop(session) diff --git a/mcp_python/server/session.py b/mcp_python/server/session.py index c64f799..375e557 100644 --- a/mcp_python/server/session.py +++ b/mcp_python/server/session.py @@ -6,11 +6,11 @@ import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +from mcp_python.server.types import InitializationOptions from mcp_python.shared.session import ( BaseSession, RequestResponder, ) -from mcp_python.server.types import InitializationOptions from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION from mcp_python.types import ( ClientNotification, @@ -25,7 +25,6 @@ from mcp_python.types import ( JSONRPCMessage, LoggingLevel, SamplingMessage, - ServerCapabilities, ServerNotification, ServerRequest, ServerResult, @@ -53,7 +52,7 @@ class ServerSession( self, read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], - init_options: InitializationOptions + init_options: InitializationOptions, ) -> None: super().__init__(read_stream, write_stream, ClientRequest, ClientNotification) self._initialization_state = InitializationState.NotInitialized @@ -72,7 +71,7 @@ class ServerSession( capabilities=self._init_options.capabilities, serverInfo=Implementation( name=self._init_options.server_name, - version=self._init_options.server_version + version=self._init_options.server_version, ), ) ) diff --git a/mcp_python/server/sse.py b/mcp_python/server/sse.py index c6e9fa6..d261c71 100644 --- a/mcp_python/server/sse.py +++ b/mcp_python/server/sse.py @@ -19,10 +19,14 @@ logger = logging.getLogger(__name__) class SseServerTransport: """ - SSE server transport for MCP. This class provides _two_ ASGI applications, suitable to be used with a framework like Starlette and a server like Hypercorn: + SSE server transport for MCP. This class provides _two_ ASGI applications, + suitable to be used with a framework like Starlette and a server like Hypercorn: - 1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client. - 2. handle_post_message() is an ASGI application which receives incoming POST requests, which should contain client messages that link to a previously-established SSE session. + 1. connect_sse() is an ASGI application which receives incoming GET requests, + and sets up a new SSE stream to send server messages to the client. + 2. handle_post_message() is an ASGI application which receives incoming POST + requests, which should contain client messages that link to a + previously-established SSE session. """ _endpoint: str @@ -30,7 +34,8 @@ class SseServerTransport: def __init__(self, endpoint: str) -> None: """ - Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given. + Creates a new SSE server transport, which will direct the client to POST + messages to the relative or absolute URL given. """ super().__init__() @@ -74,7 +79,9 @@ class SseServerTransport: await sse_stream_writer.send( { "event": "message", - "data": message.model_dump_json(by_alias=True, exclude_none=True), + "data": message.model_dump_json( + by_alias=True, exclude_none=True + ), } ) diff --git a/mcp_python/server/stdio.py b/mcp_python/server/stdio.py index b55df0e..31ae415 100644 --- a/mcp_python/server/stdio.py +++ b/mcp_python/server/stdio.py @@ -7,14 +7,18 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre from mcp_python.types import JSONRPCMessage + @asynccontextmanager async def stdio_server( - stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None + stdin: anyio.AsyncFile[str] | None = None, + stdout: anyio.AsyncFile[str] | None = None, ): """ - Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout. + Server transport for stdio: this communicates with an MCP client by reading + from the current process' stdin and writing to stdout. """ - # Purposely not using context managers for these, as we don't want to close standard process handles. + # Purposely not using context managers for these, as we don't want to close + # standard process handles. if not stdin: stdin = anyio.wrap_file(sys.stdin) if not stdout: diff --git a/mcp_python/server/types.py b/mcp_python/server/types.py index 1b56f24..7632406 100644 --- a/mcp_python/server/types.py +++ b/mcp_python/server/types.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Literal from pydantic import BaseModel + from mcp_python.types import Role, ServerCapabilities diff --git a/mcp_python/server/websocket.py b/mcp_python/server/websocket.py index 09f8a5a..5ba309b 100644 --- a/mcp_python/server/websocket.py +++ b/mcp_python/server/websocket.py @@ -14,7 +14,8 @@ logger = logging.getLogger(__name__) @asynccontextmanager async def websocket_server(scope: Scope, receive: Receive, send: Send): """ - WebSocket server transport for MCP. This is an ASGI application, suitable to be used with a framework like Starlette and a server like Hypercorn. + WebSocket server transport for MCP. This is an ASGI application, suitable to be + used with a framework like Starlette and a server like Hypercorn. """ websocket = WebSocket(scope, receive, send) @@ -47,7 +48,9 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): try: async with write_stream_reader: async for message in write_stream_reader: - obj = message.model_dump(by_alias=True, mode="json", exclude_none=True) + obj = message.model_dump( + by_alias=True, mode="json", exclude_none=True + ) await websocket.send_json(obj) except anyio.ClosedResourceError: await websocket.close() diff --git a/mcp_python/shared/session.py b/mcp_python/shared/session.py index 1705e0d..3bc66fc 100644 --- a/mcp_python/shared/session.py +++ b/mcp_python/shared/session.py @@ -69,9 +69,11 @@ class BaseSession( ], ): """ - Implements an MCP "session" on top of read/write streams, including features like request/response linking, notifications, and progress. + Implements an MCP "session" on top of read/write streams, including features + like request/response linking, notifications, and progress. - This class is an async context manager that automatically starts processing messages when entered. + This class is an async context manager that automatically starts processing + messages when entered. """ _response_streams: dict[ @@ -108,7 +110,9 @@ class BaseSession( return self async def __aexit__(self, exc_type, exc_val, exc_tb): - # Using BaseSession as a context manager should not block on exit (this would be very surprising behavior), so make sure to cancel the tasks in the task group. + # Using BaseSession as a context manager should not block on exit (this + # would be very surprising behavior), so make sure to cancel the tasks + # in the task group. self._task_group.cancel_scope.cancel() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) @@ -118,9 +122,11 @@ class BaseSession( result_type: type[ReceiveResultT], ) -> ReceiveResultT: """ - Sends a request and wait for a response. Raises an McpError if the response contains an error. + Sends a request and wait for a response. Raises an McpError if the + response contains an error. - Do not use this method to emit notifications! Use send_notification() instead. + Do not use this method to emit notifications! Use send_notification() + instead. """ request_id = self._request_id @@ -132,7 +138,9 @@ class BaseSession( self._response_streams[request_id] = response_stream jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", id=request_id, **request.model_dump(by_alias=True, mode="json", exclude_none=True) + jsonrpc="2.0", + id=request_id, + **request.model_dump(by_alias=True, mode="json", exclude_none=True), ) # TODO: Support progress callbacks @@ -147,10 +155,12 @@ class BaseSession( async def send_notification(self, notification: SendNotificationT) -> None: """ - Emits a notification, which is a one-way message that does not expect a response. + Emits a notification, which is a one-way message that does not expect + a response. """ jsonrpc_notification = JSONRPCNotification( - jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True) + jsonrpc="2.0", + **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) @@ -165,7 +175,9 @@ class BaseSession( jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", id=request_id, - result=response.model_dump(by_alias=True, mode="json", exclude_none=True), + result=response.model_dump( + by_alias=True, mode="json", exclude_none=True + ), ) await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) @@ -180,7 +192,9 @@ class BaseSession( await self._incoming_message_stream_writer.send(message) elif isinstance(message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( - message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + message.root.model_dump( + by_alias=True, mode="json", exclude_none=True + ) ) responder = RequestResponder( request_id=message.root.id, @@ -196,7 +210,9 @@ class BaseSession( await self._incoming_message_stream_writer.send(responder) elif isinstance(message.root, JSONRPCNotification): notification = self._receive_notification_type.model_validate( - message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + message.root.model_dump( + by_alias=True, mode="json", exclude_none=True + ) ) await self._received_notification(notification) @@ -208,7 +224,8 @@ class BaseSession( else: await self._incoming_message_stream_writer.send( RuntimeError( - f"Received response with an unknown request ID: {message}" + "Received response with an unknown " + f"request ID: {message}" ) ) @@ -216,21 +233,25 @@ class BaseSession( self, responder: RequestResponder[ReceiveRequestT, SendResultT] ) -> None: """ - Can be overridden by subclasses to handle a request without needing to listen on the message stream. + Can be overridden by subclasses to handle a request without needing to + listen on the message stream. - If the request is responded to within this method, it will not be forwarded on to the message stream. + If the request is responded to within this method, it will not be + forwarded on to the message stream. """ async def _received_notification(self, notification: ReceiveNotificationT) -> None: """ - Can be overridden by subclasses to handle a notification without needing to listen on the message stream. + Can be overridden by subclasses to handle a notification without needing + to listen on the message stream. """ async def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None ) -> None: """ - Sends a progress notification for a request that is currently being processed. + Sends a progress notification for a request that is currently being + processed. """ @property diff --git a/mcp_python/types.py b/mcp_python/types.py index 1087526..983964a 100644 --- a/mcp_python/types.py +++ b/mcp_python/types.py @@ -6,14 +6,19 @@ from pydantic.networks import AnyUrl """ Model Context Protocol bindings for Python -These bindings were generated from https://github.com/anthropic-experimental/mcp-spec, using Claude, with a prompt something like the following: +These bindings were generated from https://github.com/anthropic-experimental/mcp-spec, +using Claude, with a prompt something like the following: -Generate idiomatic Python bindings for this schema for MCP, or the "Model Context Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version for reference. +Generate idiomatic Python bindings for this schema for MCP, or the "Model Context +Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version +for reference. * For the bindings, let's use Pydantic V2 models. -* Each model should allow extra fields everywhere, by specifying `model_config = ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class. +* Each model should allow extra fields everywhere, by specifying `model_config = + ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class. * Union types should be represented with a Pydantic `RootModel`. -* Define additional model classes instead of using dictionaries. Do this even if they're not separate types in the schema. +* Define additional model classes instead of using dictionaries. Do this even if they're + not separate types in the schema. """ @@ -24,7 +29,10 @@ class RequestParams(BaseModel): class Meta(BaseModel): progressToken: ProgressToken | None = None """ - If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. + If specified, the caller is requesting out-of-band progress notifications for + this request (as represented by notifications/progress). The value of this + parameter is an opaque token that will be attached to any subsequent + notifications. The receiver is not obligated to provide these notifications. """ model_config = ConfigDict(extra="allow") @@ -38,9 +46,11 @@ class NotificationParams(BaseModel): _meta: Meta | None = None """ - This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. + This parameter name is reserved by MCP to allow clients and servers to attach + additional metadata to their notifications. """ + RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams) NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams) MethodT = TypeVar("MethodT", bound=str) @@ -68,7 +78,8 @@ class Result(BaseModel): _meta: dict[str, Any] | None = None """ - This result property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses. + This result property is reserved by the protocol to allow clients and servers to + attach additional metadata to their responses. """ @@ -112,9 +123,15 @@ class ErrorData(BaseModel): code: int """The error type that occurred.""" message: str - """A short description of the error. The message SHOULD be limited to a concise single sentence.""" + """ + A short description of the error. The message SHOULD be limited to a concise single + sentence. + """ data: Any | None = None - """Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.).""" + """ + Additional information about the error. The value of this member is defined by the + sender (e.g. detailed error information, nested errors etc.). + """ model_config = ConfigDict(extra="allow") @@ -182,14 +199,17 @@ class InitializeRequestParams(RequestParams): class InitializeRequest(Request): - """This request is sent from the client to the server when it first connects, asking it to begin initialization.""" + """ + This request is sent from the client to the server when it first connects, asking it + to begin initialization. + """ method: Literal["initialize"] params: InitializeRequestParams class InitializeResult(Result): - """After receiving an initialize request from the client, the server sends this response.""" + """After receiving an initialize request from the client, the server sends this.""" protocolVersion: Literal[1] """The version of the Model Context Protocol that the server wants to use.""" @@ -198,14 +218,20 @@ class InitializeResult(Result): class InitializedNotification(Notification): - """This notification is sent from the client to the server after initialization has finished.""" + """ + This notification is sent from the client to the server after initialization has + finished. + """ method: Literal["notifications/initialized"] params: NotificationParams | None = None class PingRequest(Request): - """A ping, issued by either the server or the client, to check that the other party is still alive.""" + """ + A ping, issued by either the server or the client, to check that the other party is + still alive. + """ method: Literal["ping"] params: RequestParams | None = None @@ -215,16 +241,25 @@ class ProgressNotificationParams(NotificationParams): """Parameters for progress notifications.""" progressToken: ProgressToken - """The progress token which was given in the initial request, used to associate this notification with the request that is proceeding.""" + """ + The progress token which was given in the initial request, used to associate this + notification with the request that is proceeding. + """ progress: float - """The progress thus far. This should increase every time progress is made, even if the total is unknown.""" + """ + The progress thus far. This should increase every time progress is made, even if the + total is unknown. + """ total: float | None = None """Total number of items to process (or total progress required), if known.""" model_config = ConfigDict(extra="allow") class ProgressNotification(Notification): - """An out-of-band notification used to inform the receiver of a progress update for a long-running request.""" + """ + An out-of-band notification used to inform the receiver of a progress update for a + long-running request. + """ method: Literal["notifications/progress"] params: ProgressNotificationParams @@ -251,13 +286,19 @@ class ResourceTemplate(BaseModel): """A template description for resources available on the server.""" uriTemplate: str - """A URI template (according to RFC 6570) that can be used to construct resource URIs.""" + """ + A URI template (according to RFC 6570) that can be used to construct resource + URIs. + """ name: str | None = None """A human-readable name for the type of resource this template refers to.""" description: str | None = None """A human-readable description of what this template is for.""" mimeType: str | None = None - """The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type.""" + """ + The MIME type for all resources that match this template. This should only be + included if all resources matching this template have the same type. + """ model_config = ConfigDict(extra="allow") @@ -272,7 +313,10 @@ class ReadResourceRequestParams(RequestParams): """Parameters for reading a resource.""" uri: AnyUrl - """The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it.""" + """ + The URI of the resource to read. The URI can use any protocol; it is up to the + server how to interpret it. + """ model_config = ConfigDict(extra="allow") @@ -297,7 +341,10 @@ class TextResourceContents(ResourceContents): """Text contents of a resource.""" text: str - """The text of the item. This must only be set if the item can actually be represented as text (not binary data).""" + """ + The text of the item. This must only be set if the item can actually be represented + as text (not binary data). + """ class BlobResourceContents(ResourceContents): @@ -314,7 +361,10 @@ class ReadResourceResult(Result): class ResourceListChangedNotification(Notification): - """An optional notification from the server to the client, informing it that the list of resources it can read from has changed.""" + """ + An optional notification from the server to the client, informing it that the list + of resources it can read from has changed. + """ method: Literal["notifications/resources/list_changed"] params: NotificationParams | None = None @@ -324,12 +374,18 @@ class SubscribeRequestParams(RequestParams): """Parameters for subscribing to a resource.""" uri: AnyUrl - """The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it.""" + """ + The URI of the resource to subscribe to. The URI can use any protocol; it is up to + the server how to interpret it. + """ model_config = ConfigDict(extra="allow") class SubscribeRequest(Request): - """Sent from the client to request resources/updated notifications from the server whenever a particular resource changes.""" + """ + Sent from the client to request resources/updated notifications from the server + whenever a particular resource changes. + """ method: Literal["resources/subscribe"] params: SubscribeRequestParams @@ -344,7 +400,10 @@ class UnsubscribeRequestParams(RequestParams): class UnsubscribeRequest(Request): - """Sent from the client to request cancellation of resources/updated notifications from the server.""" + """ + Sent from the client to request cancellation of resources/updated notifications from + the server. + """ method: Literal["resources/unsubscribe"] params: UnsubscribeRequestParams @@ -354,19 +413,25 @@ class ResourceUpdatedNotificationParams(NotificationParams): """Parameters for resource update notifications.""" uri: AnyUrl - """The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to.""" + """ + The URI of the resource that has been updated. This might be a sub-resource of the + one that the client actually subscribed to. + """ model_config = ConfigDict(extra="allow") class ResourceUpdatedNotification(Notification): - """A notification from the server to the client, informing it that a resource has changed and may need to be read again.""" + """ + A notification from the server to the client, informing it that a resource has + changed and may need to be read again. + """ method: Literal["notifications/resources/updated"] params: ResourceUpdatedNotificationParams class ListPromptsRequest(Request): - """Sent from the client to request a list of prompts and prompt templates the server has.""" + """Sent from the client to request a list of prompts and prompt templates.""" method: Literal["prompts/list"] params: RequestParams | None = None @@ -435,7 +500,10 @@ class ImageContent(BaseModel): data: str """The base64-encoded image data.""" mimeType: str - """The MIME type of the image. Different providers may support different image types.""" + """ + The MIME type of the image. Different providers may support different + image types. + """ model_config = ConfigDict(extra="allow") @@ -505,7 +573,10 @@ class CallToolResult(Result): class ToolListChangedNotification(Notification): - """An optional notification from the server to the client, informing it that the list of tools it offers has changed.""" + """ + An optional notification from the server to the client, informing it that the list + of tools it offers has changed. + """ method: Literal["notifications/tools/list_changed"] params: NotificationParams | None = None @@ -537,7 +608,10 @@ class LoggingMessageNotificationParams(NotificationParams): logger: str | None = None """An optional name of the logger issuing this message.""" data: Any - """The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here.""" + """ + The data to be logged, such as a string message or an object. Any JSON serializable + type is allowed here. + """ model_config = ConfigDict(extra="allow") @@ -558,7 +632,10 @@ class CreateMessageRequestParams(RequestParams): systemPrompt: str | None = None """An optional system prompt the server wants to use for sampling.""" includeContext: IncludeContext | None = None - """A request to include context from one or more MCP servers (including the caller), to be attached to the prompt.""" + """ + A request to include context from one or more MCP servers (including the caller), to + be attached to the prompt. + """ temperature: float | None = None maxTokens: int """The maximum number of tokens to sample, as requested by the server.""" @@ -638,9 +715,15 @@ class Completion(BaseModel): values: list[str] """An array of completion values. Must not exceed 100 items.""" total: int | None = None - """The total number of completion options available. This can exceed the number of values actually sent in the response.""" + """ + The total number of completion options available. This can exceed the number of + values actually sent in the response. + """ hasMore: bool | None = None - """Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown.""" + """ + Indicates whether there are additional completion options beyond those provided in + the current response, even if the exact total is unknown. + """ model_config = ConfigDict(extra="allow") diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 17634fe..cb7f038 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -59,14 +59,18 @@ async def test_client_session_initialize(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), ) ) ) jsonrpc_notification = await client_to_server_receive.receive() assert isinstance(jsonrpc_notification.root, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( - jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) + jsonrpc_notification.model_dump( + by_alias=True, mode="json", exclude_none=True + ) ) async def listen_session(): diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 01813a7..addf0f5 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -33,7 +33,13 @@ async def test_server_session_initialize(): nonlocal received_initialized async with ServerSession( - client_to_server_receive, server_to_client_send, InitializationOptions(server_name='mcp_python', server_version='0.1.0', capabilities=ServerCapabilities()) + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="mcp_python", + server_version="0.1.0", + capabilities=ServerCapabilities(), + ), ) as server_session: async for message in server_session.incoming_messages: if isinstance(message, Exception):