mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
336 lines
12 KiB
Python
336 lines
12 KiB
Python
"""
|
|
ServerSession Module

This module provides the ServerSession class, which manages communication between the
server and client in the MCP (Model Context Protocol) framework. It is most commonly
used in MCP servers to interact with the client.
|
|
|
|
Common usage pattern:
```
    server = Server(name)

    @server.call_tool()
    async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any:
        # Check client capabilities before proceeding
        if ctx.session.check_client_capability(
            types.ClientCapabilities(experimental={"advanced_tools": dict()})
        ):
            # Perform advanced tool operations
            result = await perform_advanced_tool_operation(arguments)
        else:
            # Fall back to basic tool operations
            result = await perform_basic_tool_operation(arguments)

        return result

    @server.list_prompts()
    async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
        # Access session for any necessary checks or operations
        if ctx.session.client_params:
            # Customize prompts based on client initialization parameters
            return generate_custom_prompts(ctx.session.client_params)
        else:
            return default_prompts
```
|
|
|
|
The ServerSession class is typically used internally by the Server class and should not
be instantiated directly by users of the MCP framework.
|
|
"""
|
|
|
|
from enum import Enum
|
|
from typing import Any, TypeVar
|
|
|
|
import anyio
|
|
import anyio.lowlevel
|
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
from pydantic import AnyUrl
|
|
|
|
import mcp.types as types
|
|
from mcp.server.models import InitializationOptions
|
|
from mcp.shared.message import SessionMessage
|
|
from mcp.shared.session import (
|
|
BaseSession,
|
|
RequestResponder,
|
|
)
|
|
|
|
|
|
class InitializationState(Enum):
    """Phase of the MCP initialization handshake for a server session.

    A session moves from ``NotInitialized`` to ``Initializing`` when the client's
    initialize request arrives, then to ``Initialized`` when the client's
    initialized notification arrives. Stateless sessions start at ``Initialized``.
    """

    NotInitialized = 1  # nothing received from the client yet
    Initializing = 2  # initialize request handled; awaiting initialized notification
    Initialized = 3  # handshake finished; ordinary requests/notifications accepted
|
|
|
|
|
|
# Type variable for code that is generic over ServerSession (or a subclass).
# The bound is a string forward reference because ServerSession is defined
# later in this module.
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")

# Item type of the ``incoming_messages`` stream. It is one of:
#   * a responder wrapping a client request that still needs a ServerResult,
#   * a client notification, or
#   * an Exception (the session's read stream is typed to carry exceptions
#     alongside messages).
ServerRequestResponder = (
    RequestResponder[types.ClientRequest, types.ServerResult]
    | types.ClientNotification
    | Exception
)
|
|
|
|
|
|
class ServerSession(
|
|
BaseSession[
|
|
types.ServerRequest,
|
|
types.ServerNotification,
|
|
types.ServerResult,
|
|
types.ClientRequest,
|
|
types.ClientNotification,
|
|
]
|
|
):
|
|
_initialized: InitializationState = InitializationState.NotInitialized
|
|
_client_params: types.InitializeRequestParams | None = None
|
|
|
|
def __init__(
|
|
self,
|
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
|
write_stream: MemoryObjectSendStream[SessionMessage],
|
|
init_options: InitializationOptions,
|
|
stateless: bool = False,
|
|
) -> None:
|
|
super().__init__(
|
|
read_stream, write_stream, types.ClientRequest, types.ClientNotification
|
|
)
|
|
self._initialization_state = (
|
|
InitializationState.Initialized
|
|
if stateless
|
|
else InitializationState.NotInitialized
|
|
)
|
|
|
|
self._init_options = init_options
|
|
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
|
anyio.create_memory_object_stream[ServerRequestResponder](0)
|
|
)
|
|
self._exit_stack.push_async_callback(
|
|
lambda: self._incoming_message_stream_reader.aclose()
|
|
)
|
|
|
|
@property
|
|
def client_params(self) -> types.InitializeRequestParams | None:
|
|
return self._client_params
|
|
|
|
def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
|
|
"""Check if the client supports a specific capability."""
|
|
if self._client_params is None:
|
|
return False
|
|
|
|
# Get client capabilities from initialization params
|
|
client_caps = self._client_params.capabilities
|
|
|
|
# Check each specified capability in the passed in capability object
|
|
if capability.roots is not None:
|
|
if client_caps.roots is None:
|
|
return False
|
|
if capability.roots.listChanged and not client_caps.roots.listChanged:
|
|
return False
|
|
|
|
if capability.sampling is not None:
|
|
if client_caps.sampling is None:
|
|
return False
|
|
|
|
if capability.experimental is not None:
|
|
if client_caps.experimental is None:
|
|
return False
|
|
# Check each experimental capability
|
|
for exp_key, exp_value in capability.experimental.items():
|
|
if (
|
|
exp_key not in client_caps.experimental
|
|
or client_caps.experimental[exp_key] != exp_value
|
|
):
|
|
return False
|
|
|
|
return True
|
|
|
|
async def _receive_loop(self) -> None:
|
|
async with self._incoming_message_stream_writer:
|
|
await super()._receive_loop()
|
|
|
|
async def _received_request(
|
|
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
|
|
):
|
|
match responder.request.root:
|
|
case types.InitializeRequest(params=params):
|
|
self._initialization_state = InitializationState.Initializing
|
|
self._client_params = params
|
|
with responder:
|
|
await responder.respond(
|
|
types.ServerResult(
|
|
types.InitializeResult(
|
|
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
|
capabilities=self._init_options.capabilities,
|
|
serverInfo=types.Implementation(
|
|
name=self._init_options.server_name,
|
|
version=self._init_options.server_version,
|
|
),
|
|
instructions=self._init_options.instructions,
|
|
)
|
|
)
|
|
)
|
|
case _:
|
|
if self._initialization_state != InitializationState.Initialized:
|
|
raise RuntimeError(
|
|
"Received request before initialization was complete"
|
|
)
|
|
|
|
async def _received_notification(
|
|
self, notification: types.ClientNotification
|
|
) -> None:
|
|
# Need this to avoid ASYNC910
|
|
await anyio.lowlevel.checkpoint()
|
|
match notification.root:
|
|
case types.InitializedNotification():
|
|
self._initialization_state = InitializationState.Initialized
|
|
case _:
|
|
if self._initialization_state != InitializationState.Initialized:
|
|
raise RuntimeError(
|
|
"Received notification before initialization was complete"
|
|
)
|
|
|
|
async def send_log_message(
|
|
self,
|
|
level: types.LoggingLevel,
|
|
data: Any,
|
|
logger: str | None = None,
|
|
related_request_id: types.RequestId | None = None,
|
|
) -> None:
|
|
"""Send a log message notification."""
|
|
await self.send_notification(
|
|
types.ServerNotification(
|
|
types.LoggingMessageNotification(
|
|
method="notifications/message",
|
|
params=types.LoggingMessageNotificationParams(
|
|
level=level,
|
|
data=data,
|
|
logger=logger,
|
|
),
|
|
)
|
|
),
|
|
related_request_id,
|
|
)
|
|
|
|
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
|
"""Send a resource updated notification."""
|
|
await self.send_notification(
|
|
types.ServerNotification(
|
|
types.ResourceUpdatedNotification(
|
|
method="notifications/resources/updated",
|
|
params=types.ResourceUpdatedNotificationParams(uri=uri),
|
|
)
|
|
)
|
|
)
|
|
|
|
async def create_message(
|
|
self,
|
|
messages: list[types.SamplingMessage],
|
|
*,
|
|
max_tokens: int,
|
|
system_prompt: str | None = None,
|
|
include_context: types.IncludeContext | None = None,
|
|
temperature: float | None = None,
|
|
stop_sequences: list[str] | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
model_preferences: types.ModelPreferences | None = None,
|
|
) -> types.CreateMessageResult:
|
|
"""Send a sampling/create_message request."""
|
|
return await self.send_request(
|
|
types.ServerRequest(
|
|
types.CreateMessageRequest(
|
|
method="sampling/createMessage",
|
|
params=types.CreateMessageRequestParams(
|
|
messages=messages,
|
|
systemPrompt=system_prompt,
|
|
includeContext=include_context,
|
|
temperature=temperature,
|
|
maxTokens=max_tokens,
|
|
stopSequences=stop_sequences,
|
|
metadata=metadata,
|
|
modelPreferences=model_preferences,
|
|
),
|
|
)
|
|
),
|
|
types.CreateMessageResult,
|
|
)
|
|
|
|
async def list_roots(self) -> types.ListRootsResult:
|
|
"""Send a roots/list request."""
|
|
return await self.send_request(
|
|
types.ServerRequest(
|
|
types.ListRootsRequest(
|
|
method="roots/list",
|
|
)
|
|
),
|
|
types.ListRootsResult,
|
|
)
|
|
|
|
async def send_ping(self) -> types.EmptyResult:
|
|
"""Send a ping request."""
|
|
return await self.send_request(
|
|
types.ServerRequest(
|
|
types.PingRequest(
|
|
method="ping",
|
|
)
|
|
),
|
|
types.EmptyResult,
|
|
)
|
|
|
|
async def send_progress_notification(
|
|
self,
|
|
progress_token: str | int,
|
|
progress: float,
|
|
total: float | None = None,
|
|
related_request_id: str | None = None,
|
|
) -> None:
|
|
"""Send a progress notification."""
|
|
await self.send_notification(
|
|
types.ServerNotification(
|
|
types.ProgressNotification(
|
|
method="notifications/progress",
|
|
params=types.ProgressNotificationParams(
|
|
progressToken=progress_token,
|
|
progress=progress,
|
|
total=total,
|
|
),
|
|
)
|
|
),
|
|
related_request_id,
|
|
)
|
|
|
|
async def send_resource_list_changed(self) -> None:
|
|
"""Send a resource list changed notification."""
|
|
await self.send_notification(
|
|
types.ServerNotification(
|
|
types.ResourceListChangedNotification(
|
|
method="notifications/resources/list_changed",
|
|
)
|
|
)
|
|
)
|
|
|
|
async def send_tool_list_changed(self) -> None:
|
|
"""Send a tool list changed notification."""
|
|
await self.send_notification(
|
|
types.ServerNotification(
|
|
types.ToolListChangedNotification(
|
|
method="notifications/tools/list_changed",
|
|
)
|
|
)
|
|
)
|
|
|
|
async def send_prompt_list_changed(self) -> None:
|
|
"""Send a prompt list changed notification."""
|
|
await self.send_notification(
|
|
types.ServerNotification(
|
|
types.PromptListChangedNotification(
|
|
method="notifications/prompts/list_changed",
|
|
)
|
|
)
|
|
)
|
|
|
|
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
|
|
await self._incoming_message_stream_writer.send(req)
|
|
|
|
@property
|
|
def incoming_messages(
|
|
self,
|
|
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
|
|
return self._incoming_message_stream_reader
|