From b9b44e6dadfe0378bed59d00ab2fa3ca5b4a1a46 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 11 Nov 2024 20:14:03 +0000 Subject: [PATCH] Don't re-export types We will be a bit more low level and expect callees to import mcp.types instead of relying in re-exported types. This makes usage more explicit and avoids potential collisions in mcp.server. --- src/mcp/client/session.py | 237 ++++++++++++------------------------ src/mcp/client/sse.py | 12 +- src/mcp/client/stdio.py | 12 +- src/mcp/server/__init__.py | 204 +++++++++++-------------------- src/mcp/server/session.py | 141 ++++++++------------- src/mcp/server/sse.py | 14 +-- src/mcp/server/stdio.py | 12 +- src/mcp/server/websocket.py | 12 +- 8 files changed, 231 insertions(+), 413 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a9b8d54..0f3e313 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -5,83 +5,54 @@ from pydantic import AnyUrl from mcp.shared.session import BaseSession from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - CallToolResult, - ClientCapabilities, - ClientNotification, - ClientRequest, - ClientResult, - CompleteResult, - EmptyResult, - GetPromptResult, - Implementation, - InitializedNotification, - InitializeResult, - JSONRPCMessage, - ListPromptsResult, - ListResourcesResult, - ListToolsResult, - LoggingLevel, - PromptReference, - ReadResourceResult, - ResourceReference, - RootsCapability, - ServerNotification, - ServerRequest, -) +import mcp.types as types class ClientSession( BaseSession[ - ClientRequest, - ClientNotification, - ClientResult, - ServerRequest, - ServerNotification, + types.ClientRequest, + types.ClientNotification, + types.ClientResult, + types.ServerRequest, + types.ServerNotification, ] ): def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, ) -> None: super().__init__( read_stream, write_stream, - ServerRequest, - ServerNotification, + types.ServerRequest, + types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) - async def initialize(self) -> InitializeResult: - from mcp.types import ( - InitializeRequest, - InitializeRequestParams, - ) - + async def initialize(self) -> types.InitializeResult: result = await self.send_request( - ClientRequest( - InitializeRequest( + types.ClientRequest( + types.InitializeRequest( method="initialize", - params=InitializeRequestParams( - protocolVersion=LATEST_PROTOCOL_VERSION, - capabilities=ClientCapabilities( + params=types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( sampling=None, experimental=None, - roots=RootsCapability( + roots=types.RootsCapability( # TODO: Should this be based on whether we # _will_ send notifications, or only whether # they're supported? listChanged=True ), ), - clientInfo=Implementation(name="mcp", version="0.1.0"), + clientInfo=types.Implementation(name="mcp", version="0.1.0"), ), ) ), - InitializeResult, + types.InitializeResult, ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: @@ -91,40 +62,33 @@ class ClientSession( ) await self.send_notification( - ClientNotification( - InitializedNotification(method="notifications/initialized") + types.ClientNotification( + types.InitializedNotification(method="notifications/initialized") ) ) return result - async def send_ping(self) -> EmptyResult: + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" - from mcp.types import PingRequest - return await self.send_request( - ClientRequest( - PingRequest( + types.ClientRequest( + types.PingRequest( method="ping", ) ), - EmptyResult, + types.EmptyResult, ) async def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None ) -> None: """Send a progress notification.""" - from mcp.types import ( - ProgressNotification, - ProgressNotificationParams, - ) - await self.send_notification( - ClientNotification( - ProgressNotification( + types.ClientNotification( + types.ProgressNotification( method="notifications/progress", - params=ProgressNotificationParams( + params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, total=total, @@ -133,180 +97,137 @@ class ClientSession( ) ) - async def set_logging_level(self, level: LoggingLevel) -> EmptyResult: + async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: """Send a logging/setLevel request.""" - from mcp.types import ( - SetLevelRequest, - SetLevelRequestParams, - ) - return await self.send_request( - ClientRequest( - SetLevelRequest( + types.ClientRequest( + types.SetLevelRequest( method="logging/setLevel", - params=SetLevelRequestParams(level=level), + params=types.SetLevelRequestParams(level=level), ) ), - EmptyResult, + types.EmptyResult, ) - async def list_resources(self) -> ListResourcesResult: + async def list_resources(self) -> types.ListResourcesResult: """Send a resources/list request.""" - from mcp.types import ( - ListResourcesRequest, - ) - return await self.send_request( - ClientRequest( - ListResourcesRequest( + types.ClientRequest( + types.ListResourcesRequest( method="resources/list", ) ), - ListResourcesResult, + types.ListResourcesResult, ) - async def read_resource(self, uri: AnyUrl) -> ReadResourceResult: + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: """Send a resources/read request.""" - from mcp.types import ( - ReadResourceRequest, - ReadResourceRequestParams, - ) - return await self.send_request( - ClientRequest( - ReadResourceRequest( + types.ClientRequest( + types.ReadResourceRequest( method="resources/read", - params=ReadResourceRequestParams(uri=uri), + params=types.ReadResourceRequestParams(uri=uri), ) ), - ReadResourceResult, + types.ReadResourceResult, ) - async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult: + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/subscribe request.""" - from mcp.types import ( - SubscribeRequest, - SubscribeRequestParams, - ) - return await self.send_request( - ClientRequest( - SubscribeRequest( + types.ClientRequest( + types.SubscribeRequest( method="resources/subscribe", - params=SubscribeRequestParams(uri=uri), + params=types.SubscribeRequestParams(uri=uri), ) ), - EmptyResult, + types.EmptyResult, ) - async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult: + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/unsubscribe request.""" - from mcp.types import ( - UnsubscribeRequest, - UnsubscribeRequestParams, - ) - return await self.send_request( - ClientRequest( - UnsubscribeRequest( + types.ClientRequest( + types.UnsubscribeRequest( method="resources/unsubscribe", - params=UnsubscribeRequestParams(uri=uri), + params=types.UnsubscribeRequestParams(uri=uri), ) ), - EmptyResult, + types.EmptyResult, ) async def call_tool( self, name: str, arguments: dict | None = None - ) -> CallToolResult: + ) -> types.CallToolResult: """Send a tools/call request.""" - from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - ) - return await self.send_request( - ClientRequest( - CallToolRequest( + types.ClientRequest( + types.CallToolRequest( method="tools/call", - params=CallToolRequestParams(name=name, arguments=arguments), + params=types.CallToolRequestParams(name=name, arguments=arguments), ) ), - CallToolResult, + types.CallToolResult, ) - async def list_prompts(self) -> ListPromptsResult: + async def list_prompts(self) -> types.ListPromptsResult: """Send a prompts/list request.""" - from mcp.types import ListPromptsRequest - return await self.send_request( - ClientRequest( - ListPromptsRequest( + types.ClientRequest( + types.ListPromptsRequest( method="prompts/list", ) ), - ListPromptsResult, + types.ListPromptsResult, ) async def get_prompt( self, name: str, arguments: dict[str, str] | None = None - ) -> GetPromptResult: + ) -> types.GetPromptResult: """Send a prompts/get request.""" - from mcp.types import GetPromptRequest, GetPromptRequestParams - return await self.send_request( - ClientRequest( - GetPromptRequest( + types.ClientRequest( + types.GetPromptRequest( method="prompts/get", - params=GetPromptRequestParams(name=name, arguments=arguments), + params=types.GetPromptRequestParams(name=name, arguments=arguments), ) ), - GetPromptResult, + types.GetPromptResult, ) async def complete( - self, ref: ResourceReference | PromptReference, argument: dict - ) -> CompleteResult: + self, ref: types.ResourceReference | types.PromptReference, argument: dict + ) -> types.CompleteResult: """Send a completion/complete request.""" - from mcp.types import ( - CompleteRequest, - CompleteRequestParams, - CompletionArgument, - ) - return await self.send_request( - ClientRequest( - CompleteRequest( + types.ClientRequest( + types.CompleteRequest( method="completion/complete", - params=CompleteRequestParams( + params=types.CompleteRequestParams( ref=ref, - argument=CompletionArgument(**argument), + argument=types.CompletionArgument(**argument), ), ) ), - CompleteResult, + types.CompleteResult, ) - async def list_tools(self) -> ListToolsResult: + async def list_tools(self) -> types.ListToolsResult: """Send a tools/list request.""" - from mcp.types import ListToolsRequest - return await self.send_request( - ClientRequest( - ListToolsRequest( + types.ClientRequest( + types.ListToolsRequest( method="tools/list", ) ), - ListToolsResult, + types.ListToolsResult, ) async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - from mcp.types import RootsListChangedNotification - await self.send_notification( - ClientNotification( - RootsListChangedNotification( + types.ClientNotification( + types.RootsListChangedNotification( method="notifications/roots/list_changed", ) ) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index b5c36db..c79f48a 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -9,7 +9,7 @@ from anyio.abc import TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -31,11 +31,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -85,7 +85,7 @@ async def sse_client( case "message": try: message = ( - JSONRPCMessage.model_validate_json( + types.JSONRPCMessage.model_validate_json( sse.data ) ) diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 6a29138..e79a816 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -8,7 +8,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field -from mcp.types import JSONRPCMessage +import mcp.types as types # Environment variables to inherit by default DEFAULT_INHERITED_ENV_VARS = ( @@ -72,11 +72,11 @@ async def stdio_client(server: StdioServerParameters): Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -99,7 +99,7 @@ async def stdio_client(server: StdioServerParameters): for line in lines: try: - message = JSONRPCMessage.model_validate_json(line) + message = types.JSONRPCMessage.model_validate_json(line) except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index 29133e0..abfc40d 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -12,48 +12,7 @@ from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.session import RequestResponder -from mcp.types import ( - METHOD_NOT_FOUND, - CallToolRequest, - GetPromptResult, - GetPromptRequest, - GetPromptResult, - ImageContent, - ClientNotification, - ClientRequest, - CompleteRequest, - EmbeddedResource, - EmptyResult, - ErrorData, - JSONRPCMessage, - ListPromptsRequest, - ListPromptsResult, - ListResourcesRequest, - ListResourcesResult, - ListToolsRequest, - ListToolsResult, - LoggingCapability, - LoggingLevel, - PingRequest, - ProgressNotification, - Prompt, - PromptMessage, - PromptReference, - PromptsCapability, - ReadResourceRequest, - ReadResourceResult, - Resource, - ResourceReference, - ResourcesCapability, - ServerCapabilities, - ServerResult, - SetLevelRequest, - SubscribeRequest, - TextContent, - Tool, - ToolsCapability, - UnsubscribeRequest, -) +import mcp.types as types logger = logging.getLogger(__name__) @@ -77,8 +36,8 @@ class NotificationOptions: class Server: def __init__(self, name: str): self.name = name - self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = { - PingRequest: _ping_handler, + self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { + types.PingRequest: _ping_handler, } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() @@ -116,7 +75,7 @@ class Server: self, notification_options: NotificationOptions, experimental_capabilities: dict[str, dict[str, Any]], - ) -> ServerCapabilities: + ) -> types.ServerCapabilities: """Convert existing handlers to a ServerCapabilities object.""" prompts_capability = None resources_capability = None @@ -124,28 +83,28 @@ class Server: logging_capability = None # Set prompt capabilities if handler exists - if ListPromptsRequest in self.request_handlers: - prompts_capability = PromptsCapability( + if types.ListPromptsRequest in self.request_handlers: + prompts_capability = types.PromptsCapability( listChanged=notification_options.prompts_changed ) # Set resource capabilities if handler exists - if ListResourcesRequest in self.request_handlers: - resources_capability = ResourcesCapability( + if types.ListResourcesRequest in self.request_handlers: + resources_capability = types.ResourcesCapability( subscribe=False, listChanged=notification_options.resources_changed ) # Set tool capabilities if handler exists - if ListToolsRequest in self.request_handlers: - tools_capability = ToolsCapability( + if types.ListToolsRequest in self.request_handlers: + tools_capability = types.ToolsCapability( listChanged=notification_options.tools_changed ) # Set logging capabilities if handler exists - if SetLevelRequest in self.request_handlers: - logging_capability = LoggingCapability() + if types.SetLevelRequest in self.request_handlers: + logging_capability = types.LoggingCapability() - return ServerCapabilities( + return types.ServerCapabilities( prompts=prompts_capability, resources=resources_capability, tools=tools_capability, @@ -159,14 +118,14 @@ class Server: return request_ctx.get() def list_prompts(self): - def decorator(func: Callable[[], Awaitable[list[Prompt]]]): + def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]): logger.debug("Registering handler for PromptListRequest") async def handler(_: Any): prompts = await func() - return ServerResult(ListPromptsResult(prompts=prompts)) + return types.ServerResult(types.ListPromptsResult(prompts=prompts)) - self.request_handlers[ListPromptsRequest] = handler + self.request_handlers[types.ListPromptsRequest] = handler return func return decorator @@ -174,47 +133,42 @@ class Server: def get_prompt(self): def decorator( func: Callable[ - [str, dict[str, str] | None], Awaitable[GetPromptResult] + [str, dict[str, str] | None], Awaitable[types.GetPromptResult] ], ): logger.debug("Registering handler for GetPromptRequest") - async def handler(req: GetPromptRequest): + async def handler(req: types.GetPromptRequest): prompt_get = await func(req.params.name, req.params.arguments) - return ServerResult(prompt_get) + return types.ServerResult(prompt_get) - self.request_handlers[GetPromptRequest] = handler + self.request_handlers[types.GetPromptRequest] = handler return func return decorator def list_resources(self): - def decorator(func: Callable[[], Awaitable[list[Resource]]]): + def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): logger.debug("Registering handler for ListResourcesRequest") async def handler(_: Any): resources = await func() - return ServerResult(ListResourcesResult(resources=resources)) + return types.ServerResult(types.ListResourcesResult(resources=resources)) - self.request_handlers[ListResourcesRequest] = handler + self.request_handlers[types.ListResourcesRequest] = handler return func return decorator def read_resource(self): - from mcp.types import ( - BlobResourceContents, - TextResourceContents, - ) - def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]): logger.debug("Registering handler for ReadResourceRequest") - async def handler(req: ReadResourceRequest): + async def handler(req: types.ReadResourceRequest): result = await func(req.params.uri) match result: case str(s): - content = TextResourceContents( + content = types.TextResourceContents( uri=req.params.uri, text=s, mimeType="text/plain", @@ -222,130 +176,117 @@ class Server: case bytes(b): import base64 - content = BlobResourceContents( + content = types.BlobResourceContents( uri=req.params.uri, blob=base64.urlsafe_b64encode(b).decode(), mimeType="application/octet-stream", ) - return ServerResult( - ReadResourceResult( + return types.ServerResult( + types.ReadResourceResult( contents=[content], ) ) - self.request_handlers[ReadResourceRequest] = handler + self.request_handlers[types.ReadResourceRequest] = handler return func return decorator def set_logging_level(self): - from mcp.types import EmptyResult - - def decorator(func: Callable[[LoggingLevel], Awaitable[None]]): + def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): logger.debug("Registering handler for SetLevelRequest") - async def handler(req: SetLevelRequest): + async def handler(req: types.SetLevelRequest): await func(req.params.level) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[SetLevelRequest] = handler + self.request_handlers[types.SetLevelRequest] = handler return func return decorator def subscribe_resource(self): - from mcp.types import EmptyResult - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for SubscribeRequest") - async def handler(req: SubscribeRequest): + async def handler(req: types.SubscribeRequest): await func(req.params.uri) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[SubscribeRequest] = handler + self.request_handlers[types.SubscribeRequest] = handler return func return decorator def unsubscribe_resource(self): - from mcp.types import EmptyResult - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for UnsubscribeRequest") - async def handler(req: UnsubscribeRequest): + async def handler(req: types.UnsubscribeRequest): await func(req.params.uri) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[UnsubscribeRequest] = handler + self.request_handlers[types.UnsubscribeRequest] = handler return func return decorator def list_tools(self): - def decorator(func: Callable[[], Awaitable[list[Tool]]]): + def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): logger.debug("Registering handler for ListToolsRequest") async def handler(_: Any): tools = await func() - return ServerResult(ListToolsResult(tools=tools)) + return types.ServerResult(types.ListToolsResult(tools=tools)) - self.request_handlers[ListToolsRequest] = handler + self.request_handlers[types.ListToolsRequest] = handler return func return decorator def call_tool(self): - from mcp.types import ( - CallToolResult, - EmbeddedResource, - ImageContent, - TextContent, - ) - def decorator( func: Callable[ ..., - Awaitable[Sequence[TextContent | ImageContent | EmbeddedResource]], + Awaitable[Sequence[types.TextContent | types.ImageContent | types.EmbeddedResource]], ], ): logger.debug("Registering handler for CallToolRequest") - async def handler(req: CallToolRequest): + async def handler(req: types.CallToolRequest): try: results = await func(req.params.name, (req.params.arguments or {})) content = [] for result in results: match result: case str() as text: - content.append(TextContent(type="text", text=text)) - case ImageContent() as img: + content.append(types.TextContent(type="text", text=text)) + case types.ImageContent() as img: content.append( - ImageContent( + types.ImageContent( type="image", data=img.data, mimeType=img.mimeType, ) ) - case EmbeddedResource() as resource: + case types.EmbeddedResource() as resource: content.append( - EmbeddedResource( + types.EmbeddedResource( type="resource", resource=resource.resource ) ) - return ServerResult(CallToolResult(content=content, isError=False)) + return types.ServerResult(types.CallToolResult(content=content, isError=False)) except Exception as e: - return ServerResult( - CallToolResult( - content=[TextContent(type="text", text=str(e))], + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text=str(e))], isError=True, ) ) - self.request_handlers[CallToolRequest] = handler + self.request_handlers[types.CallToolRequest] = handler return func return decorator @@ -356,47 +297,46 @@ class Server: ): logger.debug("Registering handler for ProgressNotification") - async def handler(req: ProgressNotification): + async def handler(req: types.ProgressNotification): await func( req.params.progressToken, req.params.progress, req.params.total ) - self.notification_handlers[ProgressNotification] = handler + self.notification_handlers[types.ProgressNotification] = handler return func return decorator def completion(self): """Provides completions for prompts and resource templates""" - from mcp.types import CompleteResult, Completion, CompletionArgument def decorator( func: Callable[ - [PromptReference | ResourceReference, CompletionArgument], - Awaitable[Completion | None], + [types.PromptReference | types.ResourceReference, types.CompletionArgument], + Awaitable[types.Completion | None], ], ): logger.debug("Registering handler for CompleteRequest") - async def handler(req: CompleteRequest): + async def handler(req: types.CompleteRequest): completion = await func(req.params.ref, req.params.argument) - return ServerResult( - CompleteResult( + return types.ServerResult( + types.CompleteResult( completion=completion if completion is not None - else Completion(values=[], total=None, hasMore=None), + else types.Completion(values=[], total=None, hasMore=None), ) ) - self.request_handlers[CompleteRequest] = handler + self.request_handlers[types.CompleteRequest] = handler return func return decorator async def run( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], initialization_options: InitializationOptions, # When True, exceptions are returned as messages to the client. # When False, exceptions are raised, which will cause the server to shut down @@ -412,7 +352,7 @@ class Server: logger.debug(f"Received message: {message}") match message: - case RequestResponder(request=ClientRequest(root=req)): + case RequestResponder(request=types.ClientRequest(root=req)): logger.info( f"Processing request of type {type(req).__name__}" ) @@ -437,7 +377,7 @@ class Server: except Exception as err: if raise_exceptions: raise err - response = ErrorData( + response = types.ErrorData( code=0, message=str(err), data=None ) finally: @@ -448,14 +388,14 @@ class Server: await message.respond(response) else: await message.respond( - ErrorData( - code=METHOD_NOT_FOUND, + types.ErrorData( + code=types.METHOD_NOT_FOUND, message="Method not found", ) ) logger.debug("Response sent") - case ClientNotification(root=notify): + case types.ClientNotification(root=notify): if type(notify) in self.notification_handlers: assert type(notify) in self.notification_handlers @@ -479,5 +419,5 @@ class Server: ) -async def _ping_handler(request: PingRequest) -> ServerResult: - return ServerResult(EmptyResult()) +async def _ping_handler(request: types.PingRequest) -> types.ServerResult: + return types.ServerResult(types.EmptyResult()) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 03e6882..97b70bd 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -11,29 +11,7 @@ from mcp.shared.session import ( BaseSession, RequestResponder, ) -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - ClientNotification, - ClientRequest, - CreateMessageResult, - EmptyResult, - Implementation, - IncludeContext, - InitializedNotification, - InitializeRequest, - InitializeResult, - JSONRPCMessage, - ListRootsResult, - LoggingLevel, - ModelPreferences, - PromptListChangedNotification, - ResourceListChangedNotification, - SamplingMessage, - ServerNotification, - ServerRequest, - ServerResult, - ToolListChangedNotification, -) +import mcp.types as types class InitializationState(Enum): @@ -44,37 +22,37 @@ class InitializationState(Enum): class ServerSession( BaseSession[ - ServerRequest, - ServerNotification, - ServerResult, - ClientRequest, - ClientNotification, + types.ServerRequest, + types.ServerNotification, + types.ServerResult, + types.ClientRequest, + types.ClientNotification, ] ): _initialized: InitializationState = InitializationState.NotInitialized def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, ) -> None: - super().__init__(read_stream, write_stream, ClientRequest, ClientNotification) + super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) self._initialization_state = InitializationState.NotInitialized self._init_options = init_options async def _received_request( - self, responder: RequestResponder[ClientRequest, ServerResult] + self, responder: RequestResponder[types.ClientRequest, types.ServerResult] ): match responder.request.root: - case InitializeRequest(): + case types.InitializeRequest(): self._initialization_state = InitializationState.Initializing await responder.respond( - ServerResult( - InitializeResult( - protocolVersion=LATEST_PROTOCOL_VERSION, + types.ServerResult( + types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=self._init_options.capabilities, - serverInfo=Implementation( + serverInfo=types.Implementation( name=self._init_options.server_name, version=self._init_options.server_version, ), @@ -87,11 +65,11 @@ class ServerSession( "Received request before initialization was complete" ) - async def _received_notification(self, notification: ClientNotification) -> None: + async def _received_notification(self, notification: types.ClientNotification) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() match notification.root: - case InitializedNotification(): + case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized case _: if self._initialization_state != InitializationState.Initialized: @@ -100,19 +78,14 @@ class ServerSession( ) async def send_log_message( - self, level: LoggingLevel, data: Any, logger: str | None = None + self, level: types.LoggingLevel, data: Any, logger: str | None = None ) -> None: """Send a log message notification.""" - from mcp.types import ( - LoggingMessageNotification, - LoggingMessageNotificationParams, - ) - await self.send_notification( - ServerNotification( - LoggingMessageNotification( + types.ServerNotification( + types.LoggingMessageNotification( method="notifications/message", - params=LoggingMessageNotificationParams( + params=types.LoggingMessageNotificationParams( level=level, data=data, logger=logger, @@ -123,43 +96,33 @@ class ServerSession( async def send_resource_updated(self, uri: AnyUrl) -> None: """Send a resource updated notification.""" - from mcp.types import ( - ResourceUpdatedNotification, - ResourceUpdatedNotificationParams, - ) - await self.send_notification( - ServerNotification( - ResourceUpdatedNotification( + types.ServerNotification( + types.ResourceUpdatedNotification( method="notifications/resources/updated", - params=ResourceUpdatedNotificationParams(uri=uri), + params=types.ResourceUpdatedNotificationParams(uri=uri), ) ) ) async def create_message( self, - messages: list[SamplingMessage], + messages: list[types.SamplingMessage], *, max_tokens: int, system_prompt: str | None = None, - include_context: IncludeContext | None = None, + include_context: types.IncludeContext | None = None, temperature: float | None = None, stop_sequences: list[str] | None = None, metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - ) -> CreateMessageResult: + model_preferences: types.ModelPreferences | None = None, + ) -> types.CreateMessageResult: """Send a sampling/create_message request.""" - from mcp.types import ( - CreateMessageRequest, - CreateMessageRequestParams, - ) - return await self.send_request( - ServerRequest( - CreateMessageRequest( + types.ServerRequest( + types.CreateMessageRequest( method="sampling/createMessage", - params=CreateMessageRequestParams( + params=types.CreateMessageRequestParams( messages=messages, systemPrompt=system_prompt, includeContext=include_context, @@ -171,46 +134,40 @@ class ServerSession( ), ) ), - CreateMessageResult, + types.CreateMessageResult, ) - async def list_roots(self) -> ListRootsResult: + async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" - from mcp.types import ListRootsRequest - return await self.send_request( - ServerRequest( - ListRootsRequest( + types.ServerRequest( + types.ListRootsRequest( method="roots/list", ) ), - ListRootsResult, + types.ListRootsResult, ) - async def send_ping(self) -> EmptyResult: + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" - from mcp.types import PingRequest - return await self.send_request( - ServerRequest( - PingRequest( + types.ServerRequest( + types.PingRequest( method="ping", ) ), - EmptyResult, + types.EmptyResult, ) async def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None ) -> None: """Send a progress notification.""" - from mcp.types import ProgressNotification, ProgressNotificationParams - await self.send_notification( - ServerNotification( - ProgressNotification( + types.ServerNotification( + types.ProgressNotification( method="notifications/progress", - params=ProgressNotificationParams( + params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, total=total, @@ -222,8 +179,8 @@ class ServerSession( async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification( - ServerNotification( - ResourceListChangedNotification( + types.ServerNotification( + types.ResourceListChangedNotification( method="notifications/resources/list_changed", ) ) @@ -232,8 +189,8 @@ class ServerSession( async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" await self.send_notification( - ServerNotification( - ToolListChangedNotification( + types.ServerNotification( + types.ToolListChangedNotification( method="notifications/tools/list_changed", ) ) @@ -242,8 +199,8 @@ class ServerSession( async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification( - ServerNotification( - PromptListChangedNotification( + types.ServerNotification( + types.PromptListChangedNotification( method="notifications/prompts/list_changed", ) ) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 92ebb7a..4074fdb 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -12,7 +12,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]] + _read_stream_writers: dict[UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]] def __init__(self, endpoint: str) -> None: """ @@ -50,11 +50,11 @@ class SseServerTransport: raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -125,7 +125,7 @@ class SseServerTransport: logger.debug(f"Received JSON: {json}") try: - message = JSONRPCMessage.model_validate(json) + message = types.JSONRPCMessage.model_validate(json) logger.debug(f"Validated client message: {message}") except ValidationError as err: logger.error(f"Failed to parse message: {err}") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 29f6bb6..ffe4081 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -5,7 +5,7 @@ import anyio import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.types import JSONRPCMessage +import mcp.types as types @asynccontextmanager @@ -24,11 +24,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(sys.stdout) - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -38,7 +38,7 @@ async def stdio_server( async with read_stream_writer: async for line in stdin: try: - message = JSONRPCMessage.model_validate_json(line) + message = types.JSONRPCMessage.model_validate_json(line) except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 2a6d812..bd3d632 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -6,7 +6,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -21,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -35,7 +35,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): async with read_stream_writer: async for message in websocket.iter_json(): try: - client_message = JSONRPCMessage.model_validate(message) + client_message = types.JSONRPCMessage.model_validate(message) except Exception as exc: await read_stream_writer.send(exc) continue