Don't re-export types

We will be a bit more low level and expect callees to import mcp.types
instead of relying in re-exported types. This makes usage more explicit
and avoids potential collisions in mcp.server.
This commit is contained in:
David Soria Parra
2024-11-11 20:14:03 +00:00
parent f5d82bd229
commit b9b44e6dad
8 changed files with 231 additions and 413 deletions

View File

@@ -5,83 +5,54 @@ from pydantic import AnyUrl
from mcp.shared.session import BaseSession
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
LATEST_PROTOCOL_VERSION,
CallToolResult,
ClientCapabilities,
ClientNotification,
ClientRequest,
ClientResult,
CompleteResult,
EmptyResult,
GetPromptResult,
Implementation,
InitializedNotification,
InitializeResult,
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
PromptReference,
ReadResourceResult,
ResourceReference,
RootsCapability,
ServerNotification,
ServerRequest,
)
import mcp.types as types
class ClientSession(
BaseSession[
ClientRequest,
ClientNotification,
ClientResult,
ServerRequest,
ServerNotification,
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
ServerRequest,
ServerNotification,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
async def initialize(self) -> InitializeResult:
from mcp.types import (
InitializeRequest,
InitializeRequestParams,
)
async def initialize(self) -> types.InitializeResult:
result = await self.send_request(
ClientRequest(
InitializeRequest(
types.ClientRequest(
types.InitializeRequest(
method="initialize",
params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=None,
experimental=None,
roots=RootsCapability(
roots=types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True
),
),
clientInfo=Implementation(name="mcp", version="0.1.0"),
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
),
)
),
InitializeResult,
types.InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
@@ -91,40 +62,33 @@ class ClientSession(
)
await self.send_notification(
ClientNotification(
InitializedNotification(method="notifications/initialized")
types.ClientNotification(
types.InitializedNotification(method="notifications/initialized")
)
)
return result
async def send_ping(self) -> EmptyResult:
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
from mcp.types import PingRequest
return await self.send_request(
ClientRequest(
PingRequest(
types.ClientRequest(
types.PingRequest(
method="ping",
)
),
EmptyResult,
types.EmptyResult,
)
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
from mcp.types import (
ProgressNotification,
ProgressNotificationParams,
)
await self.send_notification(
ClientNotification(
ProgressNotification(
types.ClientNotification(
types.ProgressNotification(
method="notifications/progress",
params=ProgressNotificationParams(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
@@ -133,180 +97,137 @@ class ClientSession(
)
)
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
from mcp.types import (
SetLevelRequest,
SetLevelRequestParams,
)
return await self.send_request(
ClientRequest(
SetLevelRequest(
types.ClientRequest(
types.SetLevelRequest(
method="logging/setLevel",
params=SetLevelRequestParams(level=level),
params=types.SetLevelRequestParams(level=level),
)
),
EmptyResult,
types.EmptyResult,
)
async def list_resources(self) -> ListResourcesResult:
async def list_resources(self) -> types.ListResourcesResult:
"""Send a resources/list request."""
from mcp.types import (
ListResourcesRequest,
)
return await self.send_request(
ClientRequest(
ListResourcesRequest(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
)
),
ListResourcesResult,
types.ListResourcesResult,
)
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
from mcp.types import (
ReadResourceRequest,
ReadResourceRequestParams,
)
return await self.send_request(
ClientRequest(
ReadResourceRequest(
types.ClientRequest(
types.ReadResourceRequest(
method="resources/read",
params=ReadResourceRequestParams(uri=uri),
params=types.ReadResourceRequestParams(uri=uri),
)
),
ReadResourceResult,
types.ReadResourceResult,
)
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
from mcp.types import (
SubscribeRequest,
SubscribeRequestParams,
)
return await self.send_request(
ClientRequest(
SubscribeRequest(
types.ClientRequest(
types.SubscribeRequest(
method="resources/subscribe",
params=SubscribeRequestParams(uri=uri),
params=types.SubscribeRequestParams(uri=uri),
)
),
EmptyResult,
types.EmptyResult,
)
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
from mcp.types import (
UnsubscribeRequest,
UnsubscribeRequestParams,
)
return await self.send_request(
ClientRequest(
UnsubscribeRequest(
types.ClientRequest(
types.UnsubscribeRequest(
method="resources/unsubscribe",
params=UnsubscribeRequestParams(uri=uri),
params=types.UnsubscribeRequestParams(uri=uri),
)
),
EmptyResult,
types.EmptyResult,
)
async def call_tool(
self, name: str, arguments: dict | None = None
) -> CallToolResult:
) -> types.CallToolResult:
"""Send a tools/call request."""
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
)
return await self.send_request(
ClientRequest(
CallToolRequest(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=CallToolRequestParams(name=name, arguments=arguments),
params=types.CallToolRequestParams(name=name, arguments=arguments),
)
),
CallToolResult,
types.CallToolResult,
)
async def list_prompts(self) -> ListPromptsResult:
async def list_prompts(self) -> types.ListPromptsResult:
"""Send a prompts/list request."""
from mcp.types import ListPromptsRequest
return await self.send_request(
ClientRequest(
ListPromptsRequest(
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
types.ListPromptsResult,
)
async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
) -> types.GetPromptResult:
"""Send a prompts/get request."""
from mcp.types import GetPromptRequest, GetPromptRequestParams
return await self.send_request(
ClientRequest(
GetPromptRequest(
types.ClientRequest(
types.GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
types.GetPromptResult,
)
async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
self, ref: types.ResourceReference | types.PromptReference, argument: dict
) -> types.CompleteResult:
"""Send a completion/complete request."""
from mcp.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)
return await self.send_request(
ClientRequest(
CompleteRequest(
types.ClientRequest(
types.CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
params=types.CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
argument=types.CompletionArgument(**argument),
),
)
),
CompleteResult,
types.CompleteResult,
)
async def list_tools(self) -> ListToolsResult:
async def list_tools(self) -> types.ListToolsResult:
"""Send a tools/list request."""
from mcp.types import ListToolsRequest
return await self.send_request(
ClientRequest(
ListToolsRequest(
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
types.ListToolsResult,
)
async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp.types import RootsListChangedNotification
await self.send_notification(
ClientNotification(
RootsListChangedNotification(
types.ClientNotification(
types.RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)

View File

@@ -9,7 +9,7 @@ from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from mcp.types import JSONRPCMessage
import mcp.types as types
logger = logging.getLogger(__name__)
@@ -31,11 +31,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -85,7 +85,7 @@ async def sse_client(
case "message":
try:
message = (
JSONRPCMessage.model_validate_json(
types.JSONRPCMessage.model_validate_json(
sse.data
)
)

View File

@@ -8,7 +8,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field
from mcp.types import JSONRPCMessage
import mcp.types as types
# Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = (
@@ -72,11 +72,11 @@ async def stdio_client(server: StdioServerParameters):
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -99,7 +99,7 @@ async def stdio_client(server: StdioServerParameters):
for line in lines:
try:
message = JSONRPCMessage.model_validate_json(line)
message = types.JSONRPCMessage.model_validate_json(line)
except Exception as exc:
await read_stream_writer.send(exc)
continue