mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
Don't re-export types
We will be a bit more low level and expect callees to import mcp.types instead of relying in re-exported types. This makes usage more explicit and avoids potential collisions in mcp.server.
This commit is contained in:
@@ -11,29 +11,7 @@ from mcp.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
)
|
||||
from mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CreateMessageResult,
|
||||
EmptyResult,
|
||||
Implementation,
|
||||
IncludeContext,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
ListRootsResult,
|
||||
LoggingLevel,
|
||||
ModelPreferences,
|
||||
PromptListChangedNotification,
|
||||
ResourceListChangedNotification,
|
||||
SamplingMessage,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
ToolListChangedNotification,
|
||||
)
|
||||
import mcp.types as types
|
||||
|
||||
|
||||
class InitializationState(Enum):
|
||||
@@ -44,37 +22,37 @@ class InitializationState(Enum):
|
||||
|
||||
class ServerSession(
|
||||
BaseSession[
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
ServerResult,
|
||||
ClientRequest,
|
||||
ClientNotification,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
types.ServerResult,
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
]
|
||||
):
|
||||
_initialized: InitializationState = InitializationState.NotInitialized
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
init_options: InitializationOptions,
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
||||
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
|
||||
self._initialization_state = InitializationState.NotInitialized
|
||||
self._init_options = init_options
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[ClientRequest, ServerResult]
|
||||
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
):
|
||||
match responder.request.root:
|
||||
case InitializeRequest():
|
||||
case types.InitializeRequest():
|
||||
self._initialization_state = InitializationState.Initializing
|
||||
await responder.respond(
|
||||
ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
types.ServerResult(
|
||||
types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self._init_options.capabilities,
|
||||
serverInfo=Implementation(
|
||||
serverInfo=types.Implementation(
|
||||
name=self._init_options.server_name,
|
||||
version=self._init_options.server_version,
|
||||
),
|
||||
@@ -87,11 +65,11 @@ class ServerSession(
|
||||
"Received request before initialization was complete"
|
||||
)
|
||||
|
||||
async def _received_notification(self, notification: ClientNotification) -> None:
|
||||
async def _received_notification(self, notification: types.ClientNotification) -> None:
|
||||
# Need this to avoid ASYNC910
|
||||
await anyio.lowlevel.checkpoint()
|
||||
match notification.root:
|
||||
case InitializedNotification():
|
||||
case types.InitializedNotification():
|
||||
self._initialization_state = InitializationState.Initialized
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
@@ -100,19 +78,14 @@ class ServerSession(
|
||||
)
|
||||
|
||||
async def send_log_message(
|
||||
self, level: LoggingLevel, data: Any, logger: str | None = None
|
||||
self, level: types.LoggingLevel, data: Any, logger: str | None = None
|
||||
) -> None:
|
||||
"""Send a log message notification."""
|
||||
from mcp.types import (
|
||||
LoggingMessageNotification,
|
||||
LoggingMessageNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
LoggingMessageNotification(
|
||||
types.ServerNotification(
|
||||
types.LoggingMessageNotification(
|
||||
method="notifications/message",
|
||||
params=LoggingMessageNotificationParams(
|
||||
params=types.LoggingMessageNotificationParams(
|
||||
level=level,
|
||||
data=data,
|
||||
logger=logger,
|
||||
@@ -123,43 +96,33 @@ class ServerSession(
|
||||
|
||||
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
||||
"""Send a resource updated notification."""
|
||||
from mcp.types import (
|
||||
ResourceUpdatedNotification,
|
||||
ResourceUpdatedNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ResourceUpdatedNotification(
|
||||
types.ServerNotification(
|
||||
types.ResourceUpdatedNotification(
|
||||
method="notifications/resources/updated",
|
||||
params=ResourceUpdatedNotificationParams(uri=uri),
|
||||
params=types.ResourceUpdatedNotificationParams(uri=uri),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
messages: list[SamplingMessage],
|
||||
messages: list[types.SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
system_prompt: str | None = None,
|
||||
include_context: IncludeContext | None = None,
|
||||
include_context: types.IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: ModelPreferences | None = None,
|
||||
) -> CreateMessageResult:
|
||||
model_preferences: types.ModelPreferences | None = None,
|
||||
) -> types.CreateMessageResult:
|
||||
"""Send a sampling/create_message request."""
|
||||
from mcp.types import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
CreateMessageRequest(
|
||||
types.ServerRequest(
|
||||
types.CreateMessageRequest(
|
||||
method="sampling/createMessage",
|
||||
params=CreateMessageRequestParams(
|
||||
params=types.CreateMessageRequestParams(
|
||||
messages=messages,
|
||||
systemPrompt=system_prompt,
|
||||
includeContext=include_context,
|
||||
@@ -171,46 +134,40 @@ class ServerSession(
|
||||
),
|
||||
)
|
||||
),
|
||||
CreateMessageResult,
|
||||
types.CreateMessageResult,
|
||||
)
|
||||
|
||||
async def list_roots(self) -> ListRootsResult:
|
||||
async def list_roots(self) -> types.ListRootsResult:
|
||||
"""Send a roots/list request."""
|
||||
from mcp.types import ListRootsRequest
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
ListRootsRequest(
|
||||
types.ServerRequest(
|
||||
types.ListRootsRequest(
|
||||
method="roots/list",
|
||||
)
|
||||
),
|
||||
ListRootsResult,
|
||||
types.ListRootsResult,
|
||||
)
|
||||
|
||||
async def send_ping(self) -> EmptyResult:
|
||||
async def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
from mcp.types import PingRequest
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
PingRequest(
|
||||
types.ServerRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
from mcp.types import ProgressNotification, ProgressNotificationParams
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ProgressNotification(
|
||||
types.ServerNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=ProgressNotificationParams(
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
@@ -222,8 +179,8 @@ class ServerSession(
|
||||
async def send_resource_list_changed(self) -> None:
|
||||
"""Send a resource list changed notification."""
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ResourceListChangedNotification(
|
||||
types.ServerNotification(
|
||||
types.ResourceListChangedNotification(
|
||||
method="notifications/resources/list_changed",
|
||||
)
|
||||
)
|
||||
@@ -232,8 +189,8 @@ class ServerSession(
|
||||
async def send_tool_list_changed(self) -> None:
|
||||
"""Send a tool list changed notification."""
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ToolListChangedNotification(
|
||||
types.ServerNotification(
|
||||
types.ToolListChangedNotification(
|
||||
method="notifications/tools/list_changed",
|
||||
)
|
||||
)
|
||||
@@ -242,8 +199,8 @@ class ServerSession(
|
||||
async def send_prompt_list_changed(self) -> None:
|
||||
"""Send a prompt list changed notification."""
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
PromptListChangedNotification(
|
||||
types.ServerNotification(
|
||||
types.PromptListChangedNotification(
|
||||
method="notifications/prompts/list_changed",
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user