Don't re-export types

We will be a bit more low level and expect callees to import mcp.types
instead of relying in re-exported types. This makes usage more explicit
and avoids potential collisions in mcp.server.
This commit is contained in:
David Soria Parra
2024-11-11 20:14:03 +00:00
parent f5d82bd229
commit b9b44e6dad
8 changed files with 231 additions and 413 deletions

View File

@@ -5,83 +5,54 @@ from pydantic import AnyUrl
from mcp.shared.session import BaseSession
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
LATEST_PROTOCOL_VERSION,
CallToolResult,
ClientCapabilities,
ClientNotification,
ClientRequest,
ClientResult,
CompleteResult,
EmptyResult,
GetPromptResult,
Implementation,
InitializedNotification,
InitializeResult,
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
PromptReference,
ReadResourceResult,
ResourceReference,
RootsCapability,
ServerNotification,
ServerRequest,
)
import mcp.types as types
class ClientSession(
BaseSession[
ClientRequest,
ClientNotification,
ClientResult,
ServerRequest,
ServerNotification,
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
ServerRequest,
ServerNotification,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
async def initialize(self) -> InitializeResult:
from mcp.types import (
InitializeRequest,
InitializeRequestParams,
)
async def initialize(self) -> types.InitializeResult:
result = await self.send_request(
ClientRequest(
InitializeRequest(
types.ClientRequest(
types.InitializeRequest(
method="initialize",
params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=None,
experimental=None,
roots=RootsCapability(
roots=types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True
),
),
clientInfo=Implementation(name="mcp", version="0.1.0"),
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
),
)
),
InitializeResult,
types.InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
@@ -91,40 +62,33 @@ class ClientSession(
)
await self.send_notification(
ClientNotification(
InitializedNotification(method="notifications/initialized")
types.ClientNotification(
types.InitializedNotification(method="notifications/initialized")
)
)
return result
async def send_ping(self) -> EmptyResult:
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
from mcp.types import PingRequest
return await self.send_request(
ClientRequest(
PingRequest(
types.ClientRequest(
types.PingRequest(
method="ping",
)
),
EmptyResult,
types.EmptyResult,
)
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
from mcp.types import (
ProgressNotification,
ProgressNotificationParams,
)
await self.send_notification(
ClientNotification(
ProgressNotification(
types.ClientNotification(
types.ProgressNotification(
method="notifications/progress",
params=ProgressNotificationParams(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
@@ -133,180 +97,137 @@ class ClientSession(
)
)
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
from mcp.types import (
SetLevelRequest,
SetLevelRequestParams,
)
return await self.send_request(
ClientRequest(
SetLevelRequest(
types.ClientRequest(
types.SetLevelRequest(
method="logging/setLevel",
params=SetLevelRequestParams(level=level),
params=types.SetLevelRequestParams(level=level),
)
),
EmptyResult,
types.EmptyResult,
)
async def list_resources(self) -> ListResourcesResult:
async def list_resources(self) -> types.ListResourcesResult:
"""Send a resources/list request."""
from mcp.types import (
ListResourcesRequest,
)
return await self.send_request(
ClientRequest(
ListResourcesRequest(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
)
),
ListResourcesResult,
types.ListResourcesResult,
)
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
from mcp.types import (
ReadResourceRequest,
ReadResourceRequestParams,
)
return await self.send_request(
ClientRequest(
ReadResourceRequest(
types.ClientRequest(
types.ReadResourceRequest(
method="resources/read",
params=ReadResourceRequestParams(uri=uri),
params=types.ReadResourceRequestParams(uri=uri),
)
),
ReadResourceResult,
types.ReadResourceResult,
)
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
from mcp.types import (
SubscribeRequest,
SubscribeRequestParams,
)
return await self.send_request(
ClientRequest(
SubscribeRequest(
types.ClientRequest(
types.SubscribeRequest(
method="resources/subscribe",
params=SubscribeRequestParams(uri=uri),
params=types.SubscribeRequestParams(uri=uri),
)
),
EmptyResult,
types.EmptyResult,
)
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
from mcp.types import (
UnsubscribeRequest,
UnsubscribeRequestParams,
)
return await self.send_request(
ClientRequest(
UnsubscribeRequest(
types.ClientRequest(
types.UnsubscribeRequest(
method="resources/unsubscribe",
params=UnsubscribeRequestParams(uri=uri),
params=types.UnsubscribeRequestParams(uri=uri),
)
),
EmptyResult,
types.EmptyResult,
)
async def call_tool(
self, name: str, arguments: dict | None = None
) -> CallToolResult:
) -> types.CallToolResult:
"""Send a tools/call request."""
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
)
return await self.send_request(
ClientRequest(
CallToolRequest(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=CallToolRequestParams(name=name, arguments=arguments),
params=types.CallToolRequestParams(name=name, arguments=arguments),
)
),
CallToolResult,
types.CallToolResult,
)
async def list_prompts(self) -> ListPromptsResult:
async def list_prompts(self) -> types.ListPromptsResult:
"""Send a prompts/list request."""
from mcp.types import ListPromptsRequest
return await self.send_request(
ClientRequest(
ListPromptsRequest(
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
types.ListPromptsResult,
)
async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
) -> types.GetPromptResult:
"""Send a prompts/get request."""
from mcp.types import GetPromptRequest, GetPromptRequestParams
return await self.send_request(
ClientRequest(
GetPromptRequest(
types.ClientRequest(
types.GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
types.GetPromptResult,
)
async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
self, ref: types.ResourceReference | types.PromptReference, argument: dict
) -> types.CompleteResult:
"""Send a completion/complete request."""
from mcp.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)
return await self.send_request(
ClientRequest(
CompleteRequest(
types.ClientRequest(
types.CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
params=types.CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
argument=types.CompletionArgument(**argument),
),
)
),
CompleteResult,
types.CompleteResult,
)
async def list_tools(self) -> ListToolsResult:
async def list_tools(self) -> types.ListToolsResult:
"""Send a tools/list request."""
from mcp.types import ListToolsRequest
return await self.send_request(
ClientRequest(
ListToolsRequest(
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
types.ListToolsResult,
)
async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp.types import RootsListChangedNotification
await self.send_notification(
ClientNotification(
RootsListChangedNotification(
types.ClientNotification(
types.RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)

View File

@@ -9,7 +9,7 @@ from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from mcp.types import JSONRPCMessage
import mcp.types as types
logger = logging.getLogger(__name__)
@@ -31,11 +31,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -85,7 +85,7 @@ async def sse_client(
case "message":
try:
message = (
JSONRPCMessage.model_validate_json(
types.JSONRPCMessage.model_validate_json(
sse.data
)
)

View File

@@ -8,7 +8,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field
from mcp.types import JSONRPCMessage
import mcp.types as types
# Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = (
@@ -72,11 +72,11 @@ async def stdio_client(server: StdioServerParameters):
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -99,7 +99,7 @@ async def stdio_client(server: StdioServerParameters):
for line in lines:
try:
message = JSONRPCMessage.model_validate_json(line)
message = types.JSONRPCMessage.model_validate_json(line)
except Exception as exc:
await read_stream_writer.send(exc)
continue

View File

@@ -12,48 +12,7 @@ from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.session import RequestResponder
from mcp.types import (
METHOD_NOT_FOUND,
CallToolRequest,
GetPromptResult,
GetPromptRequest,
GetPromptResult,
ImageContent,
ClientNotification,
ClientRequest,
CompleteRequest,
EmbeddedResource,
EmptyResult,
ErrorData,
JSONRPCMessage,
ListPromptsRequest,
ListPromptsResult,
ListResourcesRequest,
ListResourcesResult,
ListToolsRequest,
ListToolsResult,
LoggingCapability,
LoggingLevel,
PingRequest,
ProgressNotification,
Prompt,
PromptMessage,
PromptReference,
PromptsCapability,
ReadResourceRequest,
ReadResourceResult,
Resource,
ResourceReference,
ResourcesCapability,
ServerCapabilities,
ServerResult,
SetLevelRequest,
SubscribeRequest,
TextContent,
Tool,
ToolsCapability,
UnsubscribeRequest,
)
import mcp.types as types
logger = logging.getLogger(__name__)
@@ -77,8 +36,8 @@ class NotificationOptions:
class Server:
def __init__(self, name: str):
self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
PingRequest: _ping_handler,
self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
types.PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
self.notification_options = NotificationOptions()
@@ -116,7 +75,7 @@ class Server:
self,
notification_options: NotificationOptions,
experimental_capabilities: dict[str, dict[str, Any]],
) -> ServerCapabilities:
) -> types.ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
prompts_capability = None
resources_capability = None
@@ -124,28 +83,28 @@ class Server:
logging_capability = None
# Set prompt capabilities if handler exists
if ListPromptsRequest in self.request_handlers:
prompts_capability = PromptsCapability(
if types.ListPromptsRequest in self.request_handlers:
prompts_capability = types.PromptsCapability(
listChanged=notification_options.prompts_changed
)
# Set resource capabilities if handler exists
if ListResourcesRequest in self.request_handlers:
resources_capability = ResourcesCapability(
if types.ListResourcesRequest in self.request_handlers:
resources_capability = types.ResourcesCapability(
subscribe=False, listChanged=notification_options.resources_changed
)
# Set tool capabilities if handler exists
if ListToolsRequest in self.request_handlers:
tools_capability = ToolsCapability(
if types.ListToolsRequest in self.request_handlers:
tools_capability = types.ToolsCapability(
listChanged=notification_options.tools_changed
)
# Set logging capabilities if handler exists
if SetLevelRequest in self.request_handlers:
logging_capability = LoggingCapability()
if types.SetLevelRequest in self.request_handlers:
logging_capability = types.LoggingCapability()
return ServerCapabilities(
return types.ServerCapabilities(
prompts=prompts_capability,
resources=resources_capability,
tools=tools_capability,
@@ -159,14 +118,14 @@ class Server:
return request_ctx.get()
def list_prompts(self):
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]):
logger.debug("Registering handler for PromptListRequest")
async def handler(_: Any):
prompts = await func()
return ServerResult(ListPromptsResult(prompts=prompts))
return types.ServerResult(types.ListPromptsResult(prompts=prompts))
self.request_handlers[ListPromptsRequest] = handler
self.request_handlers[types.ListPromptsRequest] = handler
return func
return decorator
@@ -174,47 +133,42 @@ class Server:
def get_prompt(self):
def decorator(
func: Callable[
[str, dict[str, str] | None], Awaitable[GetPromptResult]
[str, dict[str, str] | None], Awaitable[types.GetPromptResult]
],
):
logger.debug("Registering handler for GetPromptRequest")
async def handler(req: GetPromptRequest):
async def handler(req: types.GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments)
return ServerResult(prompt_get)
return types.ServerResult(prompt_get)
self.request_handlers[GetPromptRequest] = handler
self.request_handlers[types.GetPromptRequest] = handler
return func
return decorator
def list_resources(self):
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
def decorator(func: Callable[[], Awaitable[list[types.Resource]]]):
logger.debug("Registering handler for ListResourcesRequest")
async def handler(_: Any):
resources = await func()
return ServerResult(ListResourcesResult(resources=resources))
return types.ServerResult(types.ListResourcesResult(resources=resources))
self.request_handlers[ListResourcesRequest] = handler
self.request_handlers[types.ListResourcesRequest] = handler
return func
return decorator
def read_resource(self):
from mcp.types import (
BlobResourceContents,
TextResourceContents,
)
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
logger.debug("Registering handler for ReadResourceRequest")
async def handler(req: ReadResourceRequest):
async def handler(req: types.ReadResourceRequest):
result = await func(req.params.uri)
match result:
case str(s):
content = TextResourceContents(
content = types.TextResourceContents(
uri=req.params.uri,
text=s,
mimeType="text/plain",
@@ -222,130 +176,117 @@ class Server:
case bytes(b):
import base64
content = BlobResourceContents(
content = types.BlobResourceContents(
uri=req.params.uri,
blob=base64.urlsafe_b64encode(b).decode(),
mimeType="application/octet-stream",
)
return ServerResult(
ReadResourceResult(
return types.ServerResult(
types.ReadResourceResult(
contents=[content],
)
)
self.request_handlers[ReadResourceRequest] = handler
self.request_handlers[types.ReadResourceRequest] = handler
return func
return decorator
def set_logging_level(self):
from mcp.types import EmptyResult
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
logger.debug("Registering handler for SetLevelRequest")
async def handler(req: SetLevelRequest):
async def handler(req: types.SetLevelRequest):
await func(req.params.level)
return ServerResult(EmptyResult())
return types.ServerResult(types.EmptyResult())
self.request_handlers[SetLevelRequest] = handler
self.request_handlers[types.SetLevelRequest] = handler
return func
return decorator
def subscribe_resource(self):
from mcp.types import EmptyResult
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for SubscribeRequest")
async def handler(req: SubscribeRequest):
async def handler(req: types.SubscribeRequest):
await func(req.params.uri)
return ServerResult(EmptyResult())
return types.ServerResult(types.EmptyResult())
self.request_handlers[SubscribeRequest] = handler
self.request_handlers[types.SubscribeRequest] = handler
return func
return decorator
def unsubscribe_resource(self):
from mcp.types import EmptyResult
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for UnsubscribeRequest")
async def handler(req: UnsubscribeRequest):
async def handler(req: types.UnsubscribeRequest):
await func(req.params.uri)
return ServerResult(EmptyResult())
return types.ServerResult(types.EmptyResult())
self.request_handlers[UnsubscribeRequest] = handler
self.request_handlers[types.UnsubscribeRequest] = handler
return func
return decorator
def list_tools(self):
def decorator(func: Callable[[], Awaitable[list[Tool]]]):
def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
logger.debug("Registering handler for ListToolsRequest")
async def handler(_: Any):
tools = await func()
return ServerResult(ListToolsResult(tools=tools))
return types.ServerResult(types.ListToolsResult(tools=tools))
self.request_handlers[ListToolsRequest] = handler
self.request_handlers[types.ListToolsRequest] = handler
return func
return decorator
def call_tool(self):
from mcp.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
)
def decorator(
func: Callable[
...,
Awaitable[Sequence[TextContent | ImageContent | EmbeddedResource]],
Awaitable[Sequence[types.TextContent | types.ImageContent | types.EmbeddedResource]],
],
):
logger.debug("Registering handler for CallToolRequest")
async def handler(req: CallToolRequest):
async def handler(req: types.CallToolRequest):
try:
results = await func(req.params.name, (req.params.arguments or {}))
content = []
for result in results:
match result:
case str() as text:
content.append(TextContent(type="text", text=text))
case ImageContent() as img:
content.append(types.TextContent(type="text", text=text))
case types.ImageContent() as img:
content.append(
ImageContent(
types.ImageContent(
type="image",
data=img.data,
mimeType=img.mimeType,
)
)
case EmbeddedResource() as resource:
case types.EmbeddedResource() as resource:
content.append(
EmbeddedResource(
types.EmbeddedResource(
type="resource", resource=resource.resource
)
)
return ServerResult(CallToolResult(content=content, isError=False))
return types.ServerResult(types.CallToolResult(content=content, isError=False))
except Exception as e:
return ServerResult(
CallToolResult(
content=[TextContent(type="text", text=str(e))],
return types.ServerResult(
types.CallToolResult(
content=[types.TextContent(type="text", text=str(e))],
isError=True,
)
)
self.request_handlers[CallToolRequest] = handler
self.request_handlers[types.CallToolRequest] = handler
return func
return decorator
@@ -356,47 +297,46 @@ class Server:
):
logger.debug("Registering handler for ProgressNotification")
async def handler(req: ProgressNotification):
async def handler(req: types.ProgressNotification):
await func(
req.params.progressToken, req.params.progress, req.params.total
)
self.notification_handlers[ProgressNotification] = handler
self.notification_handlers[types.ProgressNotification] = handler
return func
return decorator
def completion(self):
"""Provides completions for prompts and resource templates"""
from mcp.types import CompleteResult, Completion, CompletionArgument
def decorator(
func: Callable[
[PromptReference | ResourceReference, CompletionArgument],
Awaitable[Completion | None],
[types.PromptReference | types.ResourceReference, types.CompletionArgument],
Awaitable[types.Completion | None],
],
):
logger.debug("Registering handler for CompleteRequest")
async def handler(req: CompleteRequest):
async def handler(req: types.CompleteRequest):
completion = await func(req.params.ref, req.params.argument)
return ServerResult(
CompleteResult(
return types.ServerResult(
types.CompleteResult(
completion=completion
if completion is not None
else Completion(values=[], total=None, hasMore=None),
else types.Completion(values=[], total=None, hasMore=None),
)
)
self.request_handlers[CompleteRequest] = handler
self.request_handlers[types.CompleteRequest] = handler
return func
return decorator
async def run(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
initialization_options: InitializationOptions,
# When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down
@@ -412,7 +352,7 @@ class Server:
logger.debug(f"Received message: {message}")
match message:
case RequestResponder(request=ClientRequest(root=req)):
case RequestResponder(request=types.ClientRequest(root=req)):
logger.info(
f"Processing request of type {type(req).__name__}"
)
@@ -437,7 +377,7 @@ class Server:
except Exception as err:
if raise_exceptions:
raise err
response = ErrorData(
response = types.ErrorData(
code=0, message=str(err), data=None
)
finally:
@@ -448,14 +388,14 @@ class Server:
await message.respond(response)
else:
await message.respond(
ErrorData(
code=METHOD_NOT_FOUND,
types.ErrorData(
code=types.METHOD_NOT_FOUND,
message="Method not found",
)
)
logger.debug("Response sent")
case ClientNotification(root=notify):
case types.ClientNotification(root=notify):
if type(notify) in self.notification_handlers:
assert type(notify) in self.notification_handlers
@@ -479,5 +419,5 @@ class Server:
)
async def _ping_handler(request: PingRequest) -> ServerResult:
return ServerResult(EmptyResult())
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
return types.ServerResult(types.EmptyResult())

View File

@@ -11,29 +11,7 @@ from mcp.shared.session import (
BaseSession,
RequestResponder,
)
from mcp.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
CreateMessageResult,
EmptyResult,
Implementation,
IncludeContext,
InitializedNotification,
InitializeRequest,
InitializeResult,
JSONRPCMessage,
ListRootsResult,
LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage,
ServerNotification,
ServerRequest,
ServerResult,
ToolListChangedNotification,
)
import mcp.types as types
class InitializationState(Enum):
@@ -44,37 +22,37 @@ class InitializationState(Enum):
class ServerSession(
BaseSession[
ServerRequest,
ServerNotification,
ServerResult,
ClientRequest,
ClientNotification,
types.ServerRequest,
types.ServerNotification,
types.ServerResult,
types.ClientRequest,
types.ClientNotification,
]
):
_initialized: InitializationState = InitializationState.NotInitialized
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
init_options: InitializationOptions,
) -> None:
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
async def _received_request(
self, responder: RequestResponder[ClientRequest, ServerResult]
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
):
match responder.request.root:
case InitializeRequest():
case types.InitializeRequest():
self._initialization_state = InitializationState.Initializing
await responder.respond(
ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=Implementation(
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
@@ -87,11 +65,11 @@ class ServerSession(
"Received request before initialization was complete"
)
async def _received_notification(self, notification: ClientNotification) -> None:
async def _received_notification(self, notification: types.ClientNotification) -> None:
# Need this to avoid ASYNC910
await anyio.lowlevel.checkpoint()
match notification.root:
case InitializedNotification():
case types.InitializedNotification():
self._initialization_state = InitializationState.Initialized
case _:
if self._initialization_state != InitializationState.Initialized:
@@ -100,19 +78,14 @@ class ServerSession(
)
async def send_log_message(
self, level: LoggingLevel, data: Any, logger: str | None = None
self, level: types.LoggingLevel, data: Any, logger: str | None = None
) -> None:
"""Send a log message notification."""
from mcp.types import (
LoggingMessageNotification,
LoggingMessageNotificationParams,
)
await self.send_notification(
ServerNotification(
LoggingMessageNotification(
types.ServerNotification(
types.LoggingMessageNotification(
method="notifications/message",
params=LoggingMessageNotificationParams(
params=types.LoggingMessageNotificationParams(
level=level,
data=data,
logger=logger,
@@ -123,43 +96,33 @@ class ServerSession(
async def send_resource_updated(self, uri: AnyUrl) -> None:
"""Send a resource updated notification."""
from mcp.types import (
ResourceUpdatedNotification,
ResourceUpdatedNotificationParams,
)
await self.send_notification(
ServerNotification(
ResourceUpdatedNotification(
types.ServerNotification(
types.ResourceUpdatedNotification(
method="notifications/resources/updated",
params=ResourceUpdatedNotificationParams(uri=uri),
params=types.ResourceUpdatedNotificationParams(uri=uri),
)
)
)
async def create_message(
self,
messages: list[SamplingMessage],
messages: list[types.SamplingMessage],
*,
max_tokens: int,
system_prompt: str | None = None,
include_context: IncludeContext | None = None,
include_context: types.IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
) -> CreateMessageResult:
model_preferences: types.ModelPreferences | None = None,
) -> types.CreateMessageResult:
"""Send a sampling/create_message request."""
from mcp.types import (
CreateMessageRequest,
CreateMessageRequestParams,
)
return await self.send_request(
ServerRequest(
CreateMessageRequest(
types.ServerRequest(
types.CreateMessageRequest(
method="sampling/createMessage",
params=CreateMessageRequestParams(
params=types.CreateMessageRequestParams(
messages=messages,
systemPrompt=system_prompt,
includeContext=include_context,
@@ -171,46 +134,40 @@ class ServerSession(
),
)
),
CreateMessageResult,
types.CreateMessageResult,
)
async def list_roots(self) -> ListRootsResult:
async def list_roots(self) -> types.ListRootsResult:
"""Send a roots/list request."""
from mcp.types import ListRootsRequest
return await self.send_request(
ServerRequest(
ListRootsRequest(
types.ServerRequest(
types.ListRootsRequest(
method="roots/list",
)
),
ListRootsResult,
types.ListRootsResult,
)
async def send_ping(self) -> EmptyResult:
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
from mcp.types import PingRequest
return await self.send_request(
ServerRequest(
PingRequest(
types.ServerRequest(
types.PingRequest(
method="ping",
)
),
EmptyResult,
types.EmptyResult,
)
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
from mcp.types import ProgressNotification, ProgressNotificationParams
await self.send_notification(
ServerNotification(
ProgressNotification(
types.ServerNotification(
types.ProgressNotification(
method="notifications/progress",
params=ProgressNotificationParams(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
@@ -222,8 +179,8 @@ class ServerSession(
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
ServerNotification(
ResourceListChangedNotification(
types.ServerNotification(
types.ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
@@ -232,8 +189,8 @@ class ServerSession(
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
ServerNotification(
ToolListChangedNotification(
types.ServerNotification(
types.ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
@@ -242,8 +199,8 @@ class ServerSession(
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
ServerNotification(
PromptListChangedNotification(
types.ServerNotification(
types.PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)

View File

@@ -12,7 +12,7 @@ from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from mcp.types import JSONRPCMessage
import mcp.types as types
logger = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ class SseServerTransport:
"""
_endpoint: str
_read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]]
_read_stream_writers: dict[UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]]
def __init__(self, endpoint: str) -> None:
"""
@@ -50,11 +50,11 @@ class SseServerTransport:
raise ValueError("connect_sse can only handle HTTP requests")
logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -125,7 +125,7 @@ class SseServerTransport:
logger.debug(f"Received JSON: {json}")
try:
message = JSONRPCMessage.model_validate(json)
message = types.JSONRPCMessage.model_validate(json)
logger.debug(f"Validated client message: {message}")
except ValidationError as err:
logger.error(f"Failed to parse message: {err}")

View File

@@ -5,7 +5,7 @@ import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.types import JSONRPCMessage
import mcp.types as types
@asynccontextmanager
@@ -24,11 +24,11 @@ async def stdio_server(
if not stdout:
stdout = anyio.wrap_file(sys.stdout)
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -38,7 +38,7 @@ async def stdio_server(
async with read_stream_writer:
async for line in stdin:
try:
message = JSONRPCMessage.model_validate_json(line)
message = types.JSONRPCMessage.model_validate_json(line)
except Exception as exc:
await read_stream_writer.send(exc)
continue

View File

@@ -6,7 +6,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket
from mcp.types import JSONRPCMessage
import mcp.types as types
logger = logging.getLogger(__name__)
@@ -21,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
websocket = WebSocket(scope, receive, send)
await websocket.accept(subprotocol="mcp")
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -35,7 +35,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
async with read_stream_writer:
async for message in websocket.iter_json():
try:
client_message = JSONRPCMessage.model_validate(message)
client_message = types.JSONRPCMessage.model_validate(message)
except Exception as exc:
await read_stream_writer.send(exc)
continue