Don't re-export types

We will be a bit more low level and expect callees to import mcp.types
instead of relying in re-exported types. This makes usage more explicit
and avoids potential collisions in mcp.server.
This commit is contained in:
David Soria Parra
2024-11-11 20:14:03 +00:00
parent f5d82bd229
commit b9b44e6dad
8 changed files with 231 additions and 413 deletions

View File

@@ -5,83 +5,54 @@ from pydantic import AnyUrl
from mcp.shared.session import BaseSession from mcp.shared.session import BaseSession
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import ( import mcp.types as types
LATEST_PROTOCOL_VERSION,
CallToolResult,
ClientCapabilities,
ClientNotification,
ClientRequest,
ClientResult,
CompleteResult,
EmptyResult,
GetPromptResult,
Implementation,
InitializedNotification,
InitializeResult,
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
PromptReference,
ReadResourceResult,
ResourceReference,
RootsCapability,
ServerNotification,
ServerRequest,
)
class ClientSession( class ClientSession(
BaseSession[ BaseSession[
ClientRequest, types.ClientRequest,
ClientNotification, types.ClientNotification,
ClientResult, types.ClientResult,
ServerRequest, types.ServerRequest,
ServerNotification, types.ServerNotification,
] ]
): ):
def __init__( def __init__(
self, self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
read_stream, read_stream,
write_stream, write_stream,
ServerRequest, types.ServerRequest,
ServerNotification, types.ServerNotification,
read_timeout_seconds=read_timeout_seconds, read_timeout_seconds=read_timeout_seconds,
) )
async def initialize(self) -> InitializeResult: async def initialize(self) -> types.InitializeResult:
from mcp.types import (
InitializeRequest,
InitializeRequestParams,
)
result = await self.send_request( result = await self.send_request(
ClientRequest( types.ClientRequest(
InitializeRequest( types.InitializeRequest(
method="initialize", method="initialize",
params=InitializeRequestParams( params=types.InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION, protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities( capabilities=types.ClientCapabilities(
sampling=None, sampling=None,
experimental=None, experimental=None,
roots=RootsCapability( roots=types.RootsCapability(
# TODO: Should this be based on whether we # TODO: Should this be based on whether we
# _will_ send notifications, or only whether # _will_ send notifications, or only whether
# they're supported? # they're supported?
listChanged=True listChanged=True
), ),
), ),
clientInfo=Implementation(name="mcp", version="0.1.0"), clientInfo=types.Implementation(name="mcp", version="0.1.0"),
), ),
) )
), ),
InitializeResult, types.InitializeResult,
) )
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
@@ -91,40 +62,33 @@ class ClientSession(
) )
await self.send_notification( await self.send_notification(
ClientNotification( types.ClientNotification(
InitializedNotification(method="notifications/initialized") types.InitializedNotification(method="notifications/initialized")
) )
) )
return result return result
async def send_ping(self) -> EmptyResult: async def send_ping(self) -> types.EmptyResult:
"""Send a ping request.""" """Send a ping request."""
from mcp.types import PingRequest
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
PingRequest( types.PingRequest(
method="ping", method="ping",
) )
), ),
EmptyResult, types.EmptyResult,
) )
async def send_progress_notification( async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None self, progress_token: str | int, progress: float, total: float | None = None
) -> None: ) -> None:
"""Send a progress notification.""" """Send a progress notification."""
from mcp.types import (
ProgressNotification,
ProgressNotificationParams,
)
await self.send_notification( await self.send_notification(
ClientNotification( types.ClientNotification(
ProgressNotification( types.ProgressNotification(
method="notifications/progress", method="notifications/progress",
params=ProgressNotificationParams( params=types.ProgressNotificationParams(
progressToken=progress_token, progressToken=progress_token,
progress=progress, progress=progress,
total=total, total=total,
@@ -133,180 +97,137 @@ class ClientSession(
) )
) )
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult: async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request.""" """Send a logging/setLevel request."""
from mcp.types import (
SetLevelRequest,
SetLevelRequestParams,
)
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
SetLevelRequest( types.SetLevelRequest(
method="logging/setLevel", method="logging/setLevel",
params=SetLevelRequestParams(level=level), params=types.SetLevelRequestParams(level=level),
) )
), ),
EmptyResult, types.EmptyResult,
) )
async def list_resources(self) -> ListResourcesResult: async def list_resources(self) -> types.ListResourcesResult:
"""Send a resources/list request.""" """Send a resources/list request."""
from mcp.types import (
ListResourcesRequest,
)
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
ListResourcesRequest( types.ListResourcesRequest(
method="resources/list", method="resources/list",
) )
), ),
ListResourcesResult, types.ListResourcesResult,
) )
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult: async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request.""" """Send a resources/read request."""
from mcp.types import (
ReadResourceRequest,
ReadResourceRequestParams,
)
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
ReadResourceRequest( types.ReadResourceRequest(
method="resources/read", method="resources/read",
params=ReadResourceRequestParams(uri=uri), params=types.ReadResourceRequestParams(uri=uri),
) )
), ),
ReadResourceResult, types.ReadResourceResult,
) )
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult: async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request.""" """Send a resources/subscribe request."""
from mcp.types import (
SubscribeRequest,
SubscribeRequestParams,
)
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
SubscribeRequest( types.SubscribeRequest(
method="resources/subscribe", method="resources/subscribe",
params=SubscribeRequestParams(uri=uri), params=types.SubscribeRequestParams(uri=uri),
) )
), ),
EmptyResult, types.EmptyResult,
) )
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult: async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request.""" """Send a resources/unsubscribe request."""
from mcp.types import (
UnsubscribeRequest,
UnsubscribeRequestParams,
)
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
UnsubscribeRequest( types.UnsubscribeRequest(
method="resources/unsubscribe", method="resources/unsubscribe",
params=UnsubscribeRequestParams(uri=uri), params=types.UnsubscribeRequestParams(uri=uri),
) )
), ),
EmptyResult, types.EmptyResult,
) )
async def call_tool( async def call_tool(
self, name: str, arguments: dict | None = None self, name: str, arguments: dict | None = None
) -> CallToolResult: ) -> types.CallToolResult:
"""Send a tools/call request.""" """Send a tools/call request."""
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
)
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
CallToolRequest( types.CallToolRequest(
method="tools/call", method="tools/call",
params=CallToolRequestParams(name=name, arguments=arguments), params=types.CallToolRequestParams(name=name, arguments=arguments),
) )
), ),
CallToolResult, types.CallToolResult,
) )
async def list_prompts(self) -> ListPromptsResult: async def list_prompts(self) -> types.ListPromptsResult:
"""Send a prompts/list request.""" """Send a prompts/list request."""
from mcp.types import ListPromptsRequest
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
ListPromptsRequest( types.ListPromptsRequest(
method="prompts/list", method="prompts/list",
) )
), ),
ListPromptsResult, types.ListPromptsResult,
) )
async def get_prompt( async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult: ) -> types.GetPromptResult:
"""Send a prompts/get request.""" """Send a prompts/get request."""
from mcp.types import GetPromptRequest, GetPromptRequestParams
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
GetPromptRequest( types.GetPromptRequest(
method="prompts/get", method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments), params=types.GetPromptRequestParams(name=name, arguments=arguments),
) )
), ),
GetPromptResult, types.GetPromptResult,
) )
async def complete( async def complete(
self, ref: ResourceReference | PromptReference, argument: dict self, ref: types.ResourceReference | types.PromptReference, argument: dict
) -> CompleteResult: ) -> types.CompleteResult:
"""Send a completion/complete request.""" """Send a completion/complete request."""
from mcp.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
CompleteRequest( types.CompleteRequest(
method="completion/complete", method="completion/complete",
params=CompleteRequestParams( params=types.CompleteRequestParams(
ref=ref, ref=ref,
argument=CompletionArgument(**argument), argument=types.CompletionArgument(**argument),
), ),
) )
), ),
CompleteResult, types.CompleteResult,
) )
async def list_tools(self) -> ListToolsResult: async def list_tools(self) -> types.ListToolsResult:
"""Send a tools/list request.""" """Send a tools/list request."""
from mcp.types import ListToolsRequest
return await self.send_request( return await self.send_request(
ClientRequest( types.ClientRequest(
ListToolsRequest( types.ListToolsRequest(
method="tools/list", method="tools/list",
) )
), ),
ListToolsResult, types.ListToolsResult,
) )
async def send_roots_list_changed(self) -> None: async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification.""" """Send a roots/list_changed notification."""
from mcp.types import RootsListChangedNotification
await self.send_notification( await self.send_notification(
ClientNotification( types.ClientNotification(
RootsListChangedNotification( types.RootsListChangedNotification(
method="notifications/roots/list_changed", method="notifications/roots/list_changed",
) )
) )

View File

@@ -9,7 +9,7 @@ from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse from httpx_sse import aconnect_sse
from mcp.types import JSONRPCMessage import mcp.types as types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,11 +31,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new `sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`. event before disconnecting. All other HTTP operations are controlled by `timeout`.
""" """
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage] write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -85,7 +85,7 @@ async def sse_client(
case "message": case "message":
try: try:
message = ( message = (
JSONRPCMessage.model_validate_json( types.JSONRPCMessage.model_validate_json(
sse.data sse.data
) )
) )

View File

@@ -8,7 +8,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from anyio.streams.text import TextReceiveStream from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from mcp.types import JSONRPCMessage import mcp.types as types
# Environment variables to inherit by default # Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = ( DEFAULT_INHERITED_ENV_VARS = (
@@ -72,11 +72,11 @@ async def stdio_client(server: StdioServerParameters):
Client transport for stdio: this will connect to a server by spawning a Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout. process and communicating with it over stdin/stdout.
""" """
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage] write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -99,7 +99,7 @@ async def stdio_client(server: StdioServerParameters):
for line in lines: for line in lines:
try: try:
message = JSONRPCMessage.model_validate_json(line) message = types.JSONRPCMessage.model_validate_json(line)
except Exception as exc: except Exception as exc:
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue

View File

@@ -12,48 +12,7 @@ from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
from mcp.types import ( import mcp.types as types
METHOD_NOT_FOUND,
CallToolRequest,
GetPromptResult,
GetPromptRequest,
GetPromptResult,
ImageContent,
ClientNotification,
ClientRequest,
CompleteRequest,
EmbeddedResource,
EmptyResult,
ErrorData,
JSONRPCMessage,
ListPromptsRequest,
ListPromptsResult,
ListResourcesRequest,
ListResourcesResult,
ListToolsRequest,
ListToolsResult,
LoggingCapability,
LoggingLevel,
PingRequest,
ProgressNotification,
Prompt,
PromptMessage,
PromptReference,
PromptsCapability,
ReadResourceRequest,
ReadResourceResult,
Resource,
ResourceReference,
ResourcesCapability,
ServerCapabilities,
ServerResult,
SetLevelRequest,
SubscribeRequest,
TextContent,
Tool,
ToolsCapability,
UnsubscribeRequest,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -77,8 +36,8 @@ class NotificationOptions:
class Server: class Server:
def __init__(self, name: str): def __init__(self, name: str):
self.name = name self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = { self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
PingRequest: _ping_handler, types.PingRequest: _ping_handler,
} }
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
self.notification_options = NotificationOptions() self.notification_options = NotificationOptions()
@@ -116,7 +75,7 @@ class Server:
self, self,
notification_options: NotificationOptions, notification_options: NotificationOptions,
experimental_capabilities: dict[str, dict[str, Any]], experimental_capabilities: dict[str, dict[str, Any]],
) -> ServerCapabilities: ) -> types.ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object.""" """Convert existing handlers to a ServerCapabilities object."""
prompts_capability = None prompts_capability = None
resources_capability = None resources_capability = None
@@ -124,28 +83,28 @@ class Server:
logging_capability = None logging_capability = None
# Set prompt capabilities if handler exists # Set prompt capabilities if handler exists
if ListPromptsRequest in self.request_handlers: if types.ListPromptsRequest in self.request_handlers:
prompts_capability = PromptsCapability( prompts_capability = types.PromptsCapability(
listChanged=notification_options.prompts_changed listChanged=notification_options.prompts_changed
) )
# Set resource capabilities if handler exists # Set resource capabilities if handler exists
if ListResourcesRequest in self.request_handlers: if types.ListResourcesRequest in self.request_handlers:
resources_capability = ResourcesCapability( resources_capability = types.ResourcesCapability(
subscribe=False, listChanged=notification_options.resources_changed subscribe=False, listChanged=notification_options.resources_changed
) )
# Set tool capabilities if handler exists # Set tool capabilities if handler exists
if ListToolsRequest in self.request_handlers: if types.ListToolsRequest in self.request_handlers:
tools_capability = ToolsCapability( tools_capability = types.ToolsCapability(
listChanged=notification_options.tools_changed listChanged=notification_options.tools_changed
) )
# Set logging capabilities if handler exists # Set logging capabilities if handler exists
if SetLevelRequest in self.request_handlers: if types.SetLevelRequest in self.request_handlers:
logging_capability = LoggingCapability() logging_capability = types.LoggingCapability()
return ServerCapabilities( return types.ServerCapabilities(
prompts=prompts_capability, prompts=prompts_capability,
resources=resources_capability, resources=resources_capability,
tools=tools_capability, tools=tools_capability,
@@ -159,14 +118,14 @@ class Server:
return request_ctx.get() return request_ctx.get()
def list_prompts(self): def list_prompts(self):
def decorator(func: Callable[[], Awaitable[list[Prompt]]]): def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]):
logger.debug("Registering handler for PromptListRequest") logger.debug("Registering handler for PromptListRequest")
async def handler(_: Any): async def handler(_: Any):
prompts = await func() prompts = await func()
return ServerResult(ListPromptsResult(prompts=prompts)) return types.ServerResult(types.ListPromptsResult(prompts=prompts))
self.request_handlers[ListPromptsRequest] = handler self.request_handlers[types.ListPromptsRequest] = handler
return func return func
return decorator return decorator
@@ -174,47 +133,42 @@ class Server:
def get_prompt(self): def get_prompt(self):
def decorator( def decorator(
func: Callable[ func: Callable[
[str, dict[str, str] | None], Awaitable[GetPromptResult] [str, dict[str, str] | None], Awaitable[types.GetPromptResult]
], ],
): ):
logger.debug("Registering handler for GetPromptRequest") logger.debug("Registering handler for GetPromptRequest")
async def handler(req: GetPromptRequest): async def handler(req: types.GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments) prompt_get = await func(req.params.name, req.params.arguments)
return ServerResult(prompt_get) return types.ServerResult(prompt_get)
self.request_handlers[GetPromptRequest] = handler self.request_handlers[types.GetPromptRequest] = handler
return func return func
return decorator return decorator
def list_resources(self): def list_resources(self):
def decorator(func: Callable[[], Awaitable[list[Resource]]]): def decorator(func: Callable[[], Awaitable[list[types.Resource]]]):
logger.debug("Registering handler for ListResourcesRequest") logger.debug("Registering handler for ListResourcesRequest")
async def handler(_: Any): async def handler(_: Any):
resources = await func() resources = await func()
return ServerResult(ListResourcesResult(resources=resources)) return types.ServerResult(types.ListResourcesResult(resources=resources))
self.request_handlers[ListResourcesRequest] = handler self.request_handlers[types.ListResourcesRequest] = handler
return func return func
return decorator return decorator
def read_resource(self): def read_resource(self):
from mcp.types import (
BlobResourceContents,
TextResourceContents,
)
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]): def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
logger.debug("Registering handler for ReadResourceRequest") logger.debug("Registering handler for ReadResourceRequest")
async def handler(req: ReadResourceRequest): async def handler(req: types.ReadResourceRequest):
result = await func(req.params.uri) result = await func(req.params.uri)
match result: match result:
case str(s): case str(s):
content = TextResourceContents( content = types.TextResourceContents(
uri=req.params.uri, uri=req.params.uri,
text=s, text=s,
mimeType="text/plain", mimeType="text/plain",
@@ -222,130 +176,117 @@ class Server:
case bytes(b): case bytes(b):
import base64 import base64
content = BlobResourceContents( content = types.BlobResourceContents(
uri=req.params.uri, uri=req.params.uri,
blob=base64.urlsafe_b64encode(b).decode(), blob=base64.urlsafe_b64encode(b).decode(),
mimeType="application/octet-stream", mimeType="application/octet-stream",
) )
return ServerResult( return types.ServerResult(
ReadResourceResult( types.ReadResourceResult(
contents=[content], contents=[content],
) )
) )
self.request_handlers[ReadResourceRequest] = handler self.request_handlers[types.ReadResourceRequest] = handler
return func return func
return decorator return decorator
def set_logging_level(self): def set_logging_level(self):
from mcp.types import EmptyResult def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
logger.debug("Registering handler for SetLevelRequest") logger.debug("Registering handler for SetLevelRequest")
async def handler(req: SetLevelRequest): async def handler(req: types.SetLevelRequest):
await func(req.params.level) await func(req.params.level)
return ServerResult(EmptyResult()) return types.ServerResult(types.EmptyResult())
self.request_handlers[SetLevelRequest] = handler self.request_handlers[types.SetLevelRequest] = handler
return func return func
return decorator return decorator
def subscribe_resource(self): def subscribe_resource(self):
from mcp.types import EmptyResult
def decorator(func: Callable[[AnyUrl], Awaitable[None]]): def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for SubscribeRequest") logger.debug("Registering handler for SubscribeRequest")
async def handler(req: SubscribeRequest): async def handler(req: types.SubscribeRequest):
await func(req.params.uri) await func(req.params.uri)
return ServerResult(EmptyResult()) return types.ServerResult(types.EmptyResult())
self.request_handlers[SubscribeRequest] = handler self.request_handlers[types.SubscribeRequest] = handler
return func return func
return decorator return decorator
def unsubscribe_resource(self): def unsubscribe_resource(self):
from mcp.types import EmptyResult
def decorator(func: Callable[[AnyUrl], Awaitable[None]]): def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for UnsubscribeRequest") logger.debug("Registering handler for UnsubscribeRequest")
async def handler(req: UnsubscribeRequest): async def handler(req: types.UnsubscribeRequest):
await func(req.params.uri) await func(req.params.uri)
return ServerResult(EmptyResult()) return types.ServerResult(types.EmptyResult())
self.request_handlers[UnsubscribeRequest] = handler self.request_handlers[types.UnsubscribeRequest] = handler
return func return func
return decorator return decorator
def list_tools(self): def list_tools(self):
def decorator(func: Callable[[], Awaitable[list[Tool]]]): def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
logger.debug("Registering handler for ListToolsRequest") logger.debug("Registering handler for ListToolsRequest")
async def handler(_: Any): async def handler(_: Any):
tools = await func() tools = await func()
return ServerResult(ListToolsResult(tools=tools)) return types.ServerResult(types.ListToolsResult(tools=tools))
self.request_handlers[ListToolsRequest] = handler self.request_handlers[types.ListToolsRequest] = handler
return func return func
return decorator return decorator
def call_tool(self): def call_tool(self):
from mcp.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
)
def decorator( def decorator(
func: Callable[ func: Callable[
..., ...,
Awaitable[Sequence[TextContent | ImageContent | EmbeddedResource]], Awaitable[Sequence[types.TextContent | types.ImageContent | types.EmbeddedResource]],
], ],
): ):
logger.debug("Registering handler for CallToolRequest") logger.debug("Registering handler for CallToolRequest")
async def handler(req: CallToolRequest): async def handler(req: types.CallToolRequest):
try: try:
results = await func(req.params.name, (req.params.arguments or {})) results = await func(req.params.name, (req.params.arguments or {}))
content = [] content = []
for result in results: for result in results:
match result: match result:
case str() as text: case str() as text:
content.append(TextContent(type="text", text=text)) content.append(types.TextContent(type="text", text=text))
case ImageContent() as img: case types.ImageContent() as img:
content.append( content.append(
ImageContent( types.ImageContent(
type="image", type="image",
data=img.data, data=img.data,
mimeType=img.mimeType, mimeType=img.mimeType,
) )
) )
case EmbeddedResource() as resource: case types.EmbeddedResource() as resource:
content.append( content.append(
EmbeddedResource( types.EmbeddedResource(
type="resource", resource=resource.resource type="resource", resource=resource.resource
) )
) )
return ServerResult(CallToolResult(content=content, isError=False)) return types.ServerResult(types.CallToolResult(content=content, isError=False))
except Exception as e: except Exception as e:
return ServerResult( return types.ServerResult(
CallToolResult( types.CallToolResult(
content=[TextContent(type="text", text=str(e))], content=[types.TextContent(type="text", text=str(e))],
isError=True, isError=True,
) )
) )
self.request_handlers[CallToolRequest] = handler self.request_handlers[types.CallToolRequest] = handler
return func return func
return decorator return decorator
@@ -356,47 +297,46 @@ class Server:
): ):
logger.debug("Registering handler for ProgressNotification") logger.debug("Registering handler for ProgressNotification")
async def handler(req: ProgressNotification): async def handler(req: types.ProgressNotification):
await func( await func(
req.params.progressToken, req.params.progress, req.params.total req.params.progressToken, req.params.progress, req.params.total
) )
self.notification_handlers[ProgressNotification] = handler self.notification_handlers[types.ProgressNotification] = handler
return func return func
return decorator return decorator
def completion(self): def completion(self):
"""Provides completions for prompts and resource templates""" """Provides completions for prompts and resource templates"""
from mcp.types import CompleteResult, Completion, CompletionArgument
def decorator( def decorator(
func: Callable[ func: Callable[
[PromptReference | ResourceReference, CompletionArgument], [types.PromptReference | types.ResourceReference, types.CompletionArgument],
Awaitable[Completion | None], Awaitable[types.Completion | None],
], ],
): ):
logger.debug("Registering handler for CompleteRequest") logger.debug("Registering handler for CompleteRequest")
async def handler(req: CompleteRequest): async def handler(req: types.CompleteRequest):
completion = await func(req.params.ref, req.params.argument) completion = await func(req.params.ref, req.params.argument)
return ServerResult( return types.ServerResult(
CompleteResult( types.CompleteResult(
completion=completion completion=completion
if completion is not None if completion is not None
else Completion(values=[], total=None, hasMore=None), else types.Completion(values=[], total=None, hasMore=None),
) )
) )
self.request_handlers[CompleteRequest] = handler self.request_handlers[types.CompleteRequest] = handler
return func return func
return decorator return decorator
async def run( async def run(
self, self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
initialization_options: InitializationOptions, initialization_options: InitializationOptions,
# When True, exceptions are returned as messages to the client. # When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down # When False, exceptions are raised, which will cause the server to shut down
@@ -412,7 +352,7 @@ class Server:
logger.debug(f"Received message: {message}") logger.debug(f"Received message: {message}")
match message: match message:
case RequestResponder(request=ClientRequest(root=req)): case RequestResponder(request=types.ClientRequest(root=req)):
logger.info( logger.info(
f"Processing request of type {type(req).__name__}" f"Processing request of type {type(req).__name__}"
) )
@@ -437,7 +377,7 @@ class Server:
except Exception as err: except Exception as err:
if raise_exceptions: if raise_exceptions:
raise err raise err
response = ErrorData( response = types.ErrorData(
code=0, message=str(err), data=None code=0, message=str(err), data=None
) )
finally: finally:
@@ -448,14 +388,14 @@ class Server:
await message.respond(response) await message.respond(response)
else: else:
await message.respond( await message.respond(
ErrorData( types.ErrorData(
code=METHOD_NOT_FOUND, code=types.METHOD_NOT_FOUND,
message="Method not found", message="Method not found",
) )
) )
logger.debug("Response sent") logger.debug("Response sent")
case ClientNotification(root=notify): case types.ClientNotification(root=notify):
if type(notify) in self.notification_handlers: if type(notify) in self.notification_handlers:
assert type(notify) in self.notification_handlers assert type(notify) in self.notification_handlers
@@ -479,5 +419,5 @@ class Server:
) )
async def _ping_handler(request: PingRequest) -> ServerResult: async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
return ServerResult(EmptyResult()) return types.ServerResult(types.EmptyResult())

View File

@@ -11,29 +11,7 @@ from mcp.shared.session import (
BaseSession, BaseSession,
RequestResponder, RequestResponder,
) )
from mcp.types import ( import mcp.types as types
LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
CreateMessageResult,
EmptyResult,
Implementation,
IncludeContext,
InitializedNotification,
InitializeRequest,
InitializeResult,
JSONRPCMessage,
ListRootsResult,
LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage,
ServerNotification,
ServerRequest,
ServerResult,
ToolListChangedNotification,
)
class InitializationState(Enum): class InitializationState(Enum):
@@ -44,37 +22,37 @@ class InitializationState(Enum):
class ServerSession( class ServerSession(
BaseSession[ BaseSession[
ServerRequest, types.ServerRequest,
ServerNotification, types.ServerNotification,
ServerResult, types.ServerResult,
ClientRequest, types.ClientRequest,
ClientNotification, types.ClientNotification,
] ]
): ):
_initialized: InitializationState = InitializationState.NotInitialized _initialized: InitializationState = InitializationState.NotInitialized
def __init__( def __init__(
self, self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
init_options: InitializationOptions, init_options: InitializationOptions,
) -> None: ) -> None:
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification) super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
self._initialization_state = InitializationState.NotInitialized self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options self._init_options = init_options
async def _received_request( async def _received_request(
self, responder: RequestResponder[ClientRequest, ServerResult] self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
): ):
match responder.request.root: match responder.request.root:
case InitializeRequest(): case types.InitializeRequest():
self._initialization_state = InitializationState.Initializing self._initialization_state = InitializationState.Initializing
await responder.respond( await responder.respond(
ServerResult( types.ServerResult(
InitializeResult( types.InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION, protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities, capabilities=self._init_options.capabilities,
serverInfo=Implementation( serverInfo=types.Implementation(
name=self._init_options.server_name, name=self._init_options.server_name,
version=self._init_options.server_version, version=self._init_options.server_version,
), ),
@@ -87,11 +65,11 @@ class ServerSession(
"Received request before initialization was complete" "Received request before initialization was complete"
) )
async def _received_notification(self, notification: ClientNotification) -> None: async def _received_notification(self, notification: types.ClientNotification) -> None:
# Need this to avoid ASYNC910 # Need this to avoid ASYNC910
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()
match notification.root: match notification.root:
case InitializedNotification(): case types.InitializedNotification():
self._initialization_state = InitializationState.Initialized self._initialization_state = InitializationState.Initialized
case _: case _:
if self._initialization_state != InitializationState.Initialized: if self._initialization_state != InitializationState.Initialized:
@@ -100,19 +78,14 @@ class ServerSession(
) )
async def send_log_message( async def send_log_message(
self, level: LoggingLevel, data: Any, logger: str | None = None self, level: types.LoggingLevel, data: Any, logger: str | None = None
) -> None: ) -> None:
"""Send a log message notification.""" """Send a log message notification."""
from mcp.types import (
LoggingMessageNotification,
LoggingMessageNotificationParams,
)
await self.send_notification( await self.send_notification(
ServerNotification( types.ServerNotification(
LoggingMessageNotification( types.LoggingMessageNotification(
method="notifications/message", method="notifications/message",
params=LoggingMessageNotificationParams( params=types.LoggingMessageNotificationParams(
level=level, level=level,
data=data, data=data,
logger=logger, logger=logger,
@@ -123,43 +96,33 @@ class ServerSession(
async def send_resource_updated(self, uri: AnyUrl) -> None: async def send_resource_updated(self, uri: AnyUrl) -> None:
"""Send a resource updated notification.""" """Send a resource updated notification."""
from mcp.types import (
ResourceUpdatedNotification,
ResourceUpdatedNotificationParams,
)
await self.send_notification( await self.send_notification(
ServerNotification( types.ServerNotification(
ResourceUpdatedNotification( types.ResourceUpdatedNotification(
method="notifications/resources/updated", method="notifications/resources/updated",
params=ResourceUpdatedNotificationParams(uri=uri), params=types.ResourceUpdatedNotificationParams(uri=uri),
) )
) )
) )
async def create_message( async def create_message(
self, self,
messages: list[SamplingMessage], messages: list[types.SamplingMessage],
*, *,
max_tokens: int, max_tokens: int,
system_prompt: str | None = None, system_prompt: str | None = None,
include_context: IncludeContext | None = None, include_context: types.IncludeContext | None = None,
temperature: float | None = None, temperature: float | None = None,
stop_sequences: list[str] | None = None, stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None, model_preferences: types.ModelPreferences | None = None,
) -> CreateMessageResult: ) -> types.CreateMessageResult:
"""Send a sampling/create_message request.""" """Send a sampling/create_message request."""
from mcp.types import (
CreateMessageRequest,
CreateMessageRequestParams,
)
return await self.send_request( return await self.send_request(
ServerRequest( types.ServerRequest(
CreateMessageRequest( types.CreateMessageRequest(
method="sampling/createMessage", method="sampling/createMessage",
params=CreateMessageRequestParams( params=types.CreateMessageRequestParams(
messages=messages, messages=messages,
systemPrompt=system_prompt, systemPrompt=system_prompt,
includeContext=include_context, includeContext=include_context,
@@ -171,46 +134,40 @@ class ServerSession(
), ),
) )
), ),
CreateMessageResult, types.CreateMessageResult,
) )
async def list_roots(self) -> ListRootsResult: async def list_roots(self) -> types.ListRootsResult:
"""Send a roots/list request.""" """Send a roots/list request."""
from mcp.types import ListRootsRequest
return await self.send_request( return await self.send_request(
ServerRequest( types.ServerRequest(
ListRootsRequest( types.ListRootsRequest(
method="roots/list", method="roots/list",
) )
), ),
ListRootsResult, types.ListRootsResult,
) )
async def send_ping(self) -> EmptyResult: async def send_ping(self) -> types.EmptyResult:
"""Send a ping request.""" """Send a ping request."""
from mcp.types import PingRequest
return await self.send_request( return await self.send_request(
ServerRequest( types.ServerRequest(
PingRequest( types.PingRequest(
method="ping", method="ping",
) )
), ),
EmptyResult, types.EmptyResult,
) )
async def send_progress_notification( async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None self, progress_token: str | int, progress: float, total: float | None = None
) -> None: ) -> None:
"""Send a progress notification.""" """Send a progress notification."""
from mcp.types import ProgressNotification, ProgressNotificationParams
await self.send_notification( await self.send_notification(
ServerNotification( types.ServerNotification(
ProgressNotification( types.ProgressNotification(
method="notifications/progress", method="notifications/progress",
params=ProgressNotificationParams( params=types.ProgressNotificationParams(
progressToken=progress_token, progressToken=progress_token,
progress=progress, progress=progress,
total=total, total=total,
@@ -222,8 +179,8 @@ class ServerSession(
async def send_resource_list_changed(self) -> None: async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification.""" """Send a resource list changed notification."""
await self.send_notification( await self.send_notification(
ServerNotification( types.ServerNotification(
ResourceListChangedNotification( types.ResourceListChangedNotification(
method="notifications/resources/list_changed", method="notifications/resources/list_changed",
) )
) )
@@ -232,8 +189,8 @@ class ServerSession(
async def send_tool_list_changed(self) -> None: async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification.""" """Send a tool list changed notification."""
await self.send_notification( await self.send_notification(
ServerNotification( types.ServerNotification(
ToolListChangedNotification( types.ToolListChangedNotification(
method="notifications/tools/list_changed", method="notifications/tools/list_changed",
) )
) )
@@ -242,8 +199,8 @@ class ServerSession(
async def send_prompt_list_changed(self) -> None: async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification.""" """Send a prompt list changed notification."""
await self.send_notification( await self.send_notification(
ServerNotification( types.ServerNotification(
PromptListChangedNotification( types.PromptListChangedNotification(
method="notifications/prompts/list_changed", method="notifications/prompts/list_changed",
) )
) )

View File

@@ -12,7 +12,7 @@ from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
from mcp.types import JSONRPCMessage import mcp.types as types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ class SseServerTransport:
""" """
_endpoint: str _endpoint: str
_read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]] _read_stream_writers: dict[UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]]
def __init__(self, endpoint: str) -> None: def __init__(self, endpoint: str) -> None:
""" """
@@ -50,11 +50,11 @@ class SseServerTransport:
raise ValueError("connect_sse can only handle HTTP requests") raise ValueError("connect_sse can only handle HTTP requests")
logger.debug("Setting up SSE connection") logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage] write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -125,7 +125,7 @@ class SseServerTransport:
logger.debug(f"Received JSON: {json}") logger.debug(f"Received JSON: {json}")
try: try:
message = JSONRPCMessage.model_validate(json) message = types.JSONRPCMessage.model_validate(json)
logger.debug(f"Validated client message: {message}") logger.debug(f"Validated client message: {message}")
except ValidationError as err: except ValidationError as err:
logger.error(f"Failed to parse message: {err}") logger.error(f"Failed to parse message: {err}")

View File

@@ -5,7 +5,7 @@ import anyio
import anyio.lowlevel import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.types import JSONRPCMessage import mcp.types as types
@asynccontextmanager @asynccontextmanager
@@ -24,11 +24,11 @@ async def stdio_server(
if not stdout: if not stdout:
stdout = anyio.wrap_file(sys.stdout) stdout = anyio.wrap_file(sys.stdout)
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage] write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -38,7 +38,7 @@ async def stdio_server(
async with read_stream_writer: async with read_stream_writer:
async for line in stdin: async for line in stdin:
try: try:
message = JSONRPCMessage.model_validate_json(line) message = types.JSONRPCMessage.model_validate_json(line)
except Exception as exc: except Exception as exc:
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue

View File

@@ -6,7 +6,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from mcp.types import JSONRPCMessage import mcp.types as types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -21,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
websocket = WebSocket(scope, receive, send) websocket = WebSocket(scope, receive, send)
await websocket.accept(subprotocol="mcp") await websocket.accept(subprotocol="mcp")
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage] write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -35,7 +35,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
async with read_stream_writer: async with read_stream_writer:
async for message in websocket.iter_json(): async for message in websocket.iter_json():
try: try:
client_message = JSONRPCMessage.model_validate(message) client_message = types.JSONRPCMessage.model_validate(message)
except Exception as exc: except Exception as exc:
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue