mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2026-01-14 19:24:25 +01:00
Don't re-export types
We will be a bit more low level and expect callees to import mcp.types instead of relying in re-exported types. This makes usage more explicit and avoids potential collisions in mcp.server.
This commit is contained in:
@@ -5,83 +5,54 @@ from pydantic import AnyUrl
|
||||
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
from mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
CallToolResult,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
CompleteResult,
|
||||
EmptyResult,
|
||||
GetPromptResult,
|
||||
Implementation,
|
||||
InitializedNotification,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
ListPromptsResult,
|
||||
ListResourcesResult,
|
||||
ListToolsResult,
|
||||
LoggingLevel,
|
||||
PromptReference,
|
||||
ReadResourceResult,
|
||||
ResourceReference,
|
||||
RootsCapability,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
)
|
||||
import mcp.types as types
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
ClientRequest,
|
||||
ClientNotification,
|
||||
ClientResult,
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
|
||||
async def initialize(self) -> InitializeResult:
|
||||
from mcp.types import (
|
||||
InitializeRequest,
|
||||
InitializeRequestParams,
|
||||
)
|
||||
|
||||
async def initialize(self) -> types.InitializeResult:
|
||||
result = await self.send_request(
|
||||
ClientRequest(
|
||||
InitializeRequest(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
method="initialize",
|
||||
params=InitializeRequestParams(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ClientCapabilities(
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=None,
|
||||
experimental=None,
|
||||
roots=RootsCapability(
|
||||
roots=types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True
|
||||
),
|
||||
),
|
||||
clientInfo=Implementation(name="mcp", version="0.1.0"),
|
||||
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
|
||||
),
|
||||
)
|
||||
),
|
||||
InitializeResult,
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
@@ -91,40 +62,33 @@ class ClientSession(
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
InitializedNotification(method="notifications/initialized")
|
||||
types.ClientNotification(
|
||||
types.InitializedNotification(method="notifications/initialized")
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def send_ping(self) -> EmptyResult:
|
||||
async def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
from mcp.types import PingRequest
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
PingRequest(
|
||||
types.ClientRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
from mcp.types import (
|
||||
ProgressNotification,
|
||||
ProgressNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
ProgressNotification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=ProgressNotificationParams(
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
@@ -133,180 +97,137 @@ class ClientSession(
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
|
||||
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
from mcp.types import (
|
||||
SetLevelRequest,
|
||||
SetLevelRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
SetLevelRequest(
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=SetLevelRequestParams(level=level),
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def list_resources(self) -> ListResourcesResult:
|
||||
async def list_resources(self) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
from mcp.types import (
|
||||
ListResourcesRequest,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ListResourcesRequest(
|
||||
types.ClientRequest(
|
||||
types.ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
ListResourcesResult,
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
|
||||
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
from mcp.types import (
|
||||
ReadResourceRequest,
|
||||
ReadResourceRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ReadResourceRequest(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=ReadResourceRequestParams(uri=uri),
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
ReadResourceResult,
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
from mcp.types import (
|
||||
SubscribeRequest,
|
||||
SubscribeRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
SubscribeRequest(
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=SubscribeRequestParams(uri=uri),
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
from mcp.types import (
|
||||
UnsubscribeRequest,
|
||||
UnsubscribeRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
UnsubscribeRequest(
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=UnsubscribeRequestParams(uri=uri),
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict | None = None
|
||||
) -> CallToolResult:
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
from mcp.types import (
|
||||
CallToolRequest,
|
||||
CallToolRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
CallToolRequest(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=CallToolRequestParams(name=name, arguments=arguments),
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
CallToolResult,
|
||||
types.CallToolResult,
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> ListPromptsResult:
|
||||
async def list_prompts(self) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request."""
|
||||
from mcp.types import ListPromptsRequest
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ListPromptsRequest(
|
||||
types.ClientRequest(
|
||||
types.ListPromptsRequest(
|
||||
method="prompts/list",
|
||||
)
|
||||
),
|
||||
ListPromptsResult,
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: dict[str, str] | None = None
|
||||
) -> GetPromptResult:
|
||||
) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
from mcp.types import GetPromptRequest, GetPromptRequestParams
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
GetPromptRequest(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
method="prompts/get",
|
||||
params=GetPromptRequestParams(name=name, arguments=arguments),
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
GetPromptResult,
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
async def complete(
|
||||
self, ref: ResourceReference | PromptReference, argument: dict
|
||||
) -> CompleteResult:
|
||||
self, ref: types.ResourceReference | types.PromptReference, argument: dict
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
from mcp.types import (
|
||||
CompleteRequest,
|
||||
CompleteRequestParams,
|
||||
CompletionArgument,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
CompleteRequest(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
method="completion/complete",
|
||||
params=CompleteRequestParams(
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=CompletionArgument(**argument),
|
||||
argument=types.CompletionArgument(**argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
CompleteResult,
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
async def list_tools(self) -> ListToolsResult:
|
||||
async def list_tools(self) -> types.ListToolsResult:
|
||||
"""Send a tools/list request."""
|
||||
from mcp.types import ListToolsRequest
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ListToolsRequest(
|
||||
types.ClientRequest(
|
||||
types.ListToolsRequest(
|
||||
method="tools/list",
|
||||
)
|
||||
),
|
||||
ListToolsResult,
|
||||
types.ListToolsResult,
|
||||
)
|
||||
|
||||
async def send_roots_list_changed(self) -> None:
|
||||
"""Send a roots/list_changed notification."""
|
||||
from mcp.types import RootsListChangedNotification
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
RootsListChangedNotification(
|
||||
types.ClientNotification(
|
||||
types.RootsListChangedNotification(
|
||||
method="notifications/roots/list_changed",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
from mcp.types import JSONRPCMessage
|
||||
import mcp.types as types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,11 +31,11 @@ async def sse_client(
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
@@ -85,7 +85,7 @@ async def sse_client(
|
||||
case "message":
|
||||
try:
|
||||
message = (
|
||||
JSONRPCMessage.model_validate_json(
|
||||
types.JSONRPCMessage.model_validate_json(
|
||||
sse.data
|
||||
)
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp.types import JSONRPCMessage
|
||||
import mcp.types as types
|
||||
|
||||
# Environment variables to inherit by default
|
||||
DEFAULT_INHERITED_ENV_VARS = (
|
||||
@@ -72,11 +72,11 @@ async def stdio_client(server: StdioServerParameters):
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
@@ -99,7 +99,7 @@ async def stdio_client(server: StdioServerParameters):
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(line)
|
||||
message = types.JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user