Don't re-export types

We will be a bit more low level and expect callees to import mcp.types
instead of relying in re-exported types. This makes usage more explicit
and avoids potential collisions in mcp.server.
This commit is contained in:
David Soria Parra
2024-11-11 20:14:03 +00:00
parent f5d82bd229
commit b9b44e6dad
8 changed files with 231 additions and 413 deletions

View File

@@ -11,29 +11,7 @@ from mcp.shared.session import (
BaseSession,
RequestResponder,
)
from mcp.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
CreateMessageResult,
EmptyResult,
Implementation,
IncludeContext,
InitializedNotification,
InitializeRequest,
InitializeResult,
JSONRPCMessage,
ListRootsResult,
LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage,
ServerNotification,
ServerRequest,
ServerResult,
ToolListChangedNotification,
)
import mcp.types as types
class InitializationState(Enum):
@@ -44,37 +22,37 @@ class InitializationState(Enum):
class ServerSession(
BaseSession[
ServerRequest,
ServerNotification,
ServerResult,
ClientRequest,
ClientNotification,
types.ServerRequest,
types.ServerNotification,
types.ServerResult,
types.ClientRequest,
types.ClientNotification,
]
):
_initialized: InitializationState = InitializationState.NotInitialized
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
init_options: InitializationOptions,
) -> None:
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
async def _received_request(
self, responder: RequestResponder[ClientRequest, ServerResult]
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
):
match responder.request.root:
case InitializeRequest():
case types.InitializeRequest():
self._initialization_state = InitializationState.Initializing
await responder.respond(
ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=Implementation(
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
@@ -87,11 +65,11 @@ class ServerSession(
"Received request before initialization was complete"
)
async def _received_notification(self, notification: ClientNotification) -> None:
async def _received_notification(self, notification: types.ClientNotification) -> None:
# Need this to avoid ASYNC910
await anyio.lowlevel.checkpoint()
match notification.root:
case InitializedNotification():
case types.InitializedNotification():
self._initialization_state = InitializationState.Initialized
case _:
if self._initialization_state != InitializationState.Initialized:
@@ -100,19 +78,14 @@ class ServerSession(
)
async def send_log_message(
self, level: LoggingLevel, data: Any, logger: str | None = None
self, level: types.LoggingLevel, data: Any, logger: str | None = None
) -> None:
"""Send a log message notification."""
from mcp.types import (
LoggingMessageNotification,
LoggingMessageNotificationParams,
)
await self.send_notification(
ServerNotification(
LoggingMessageNotification(
types.ServerNotification(
types.LoggingMessageNotification(
method="notifications/message",
params=LoggingMessageNotificationParams(
params=types.LoggingMessageNotificationParams(
level=level,
data=data,
logger=logger,
@@ -123,43 +96,33 @@ class ServerSession(
async def send_resource_updated(self, uri: AnyUrl) -> None:
"""Send a resource updated notification."""
from mcp.types import (
ResourceUpdatedNotification,
ResourceUpdatedNotificationParams,
)
await self.send_notification(
ServerNotification(
ResourceUpdatedNotification(
types.ServerNotification(
types.ResourceUpdatedNotification(
method="notifications/resources/updated",
params=ResourceUpdatedNotificationParams(uri=uri),
params=types.ResourceUpdatedNotificationParams(uri=uri),
)
)
)
async def create_message(
self,
messages: list[SamplingMessage],
messages: list[types.SamplingMessage],
*,
max_tokens: int,
system_prompt: str | None = None,
include_context: IncludeContext | None = None,
include_context: types.IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
) -> CreateMessageResult:
model_preferences: types.ModelPreferences | None = None,
) -> types.CreateMessageResult:
"""Send a sampling/create_message request."""
from mcp.types import (
CreateMessageRequest,
CreateMessageRequestParams,
)
return await self.send_request(
ServerRequest(
CreateMessageRequest(
types.ServerRequest(
types.CreateMessageRequest(
method="sampling/createMessage",
params=CreateMessageRequestParams(
params=types.CreateMessageRequestParams(
messages=messages,
systemPrompt=system_prompt,
includeContext=include_context,
@@ -171,46 +134,40 @@ class ServerSession(
),
)
),
CreateMessageResult,
types.CreateMessageResult,
)
async def list_roots(self) -> ListRootsResult:
async def list_roots(self) -> types.ListRootsResult:
"""Send a roots/list request."""
from mcp.types import ListRootsRequest
return await self.send_request(
ServerRequest(
ListRootsRequest(
types.ServerRequest(
types.ListRootsRequest(
method="roots/list",
)
),
ListRootsResult,
types.ListRootsResult,
)
async def send_ping(self) -> EmptyResult:
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
from mcp.types import PingRequest
return await self.send_request(
ServerRequest(
PingRequest(
types.ServerRequest(
types.PingRequest(
method="ping",
)
),
EmptyResult,
types.EmptyResult,
)
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
from mcp.types import ProgressNotification, ProgressNotificationParams
await self.send_notification(
ServerNotification(
ProgressNotification(
types.ServerNotification(
types.ProgressNotification(
method="notifications/progress",
params=ProgressNotificationParams(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
@@ -222,8 +179,8 @@ class ServerSession(
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
ServerNotification(
ResourceListChangedNotification(
types.ServerNotification(
types.ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
@@ -232,8 +189,8 @@ class ServerSession(
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
ServerNotification(
ToolListChangedNotification(
types.ServerNotification(
types.ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
@@ -242,8 +199,8 @@ class ServerSession(
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
ServerNotification(
PromptListChangedNotification(
types.ServerNotification(
types.PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)