mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
201 lines
6.5 KiB
Python
201 lines
6.5 KiB
Python
from enum import Enum
|
|
from typing import Any
|
|
|
|
import anyio
|
|
import anyio.lowlevel
|
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
from pydantic import AnyUrl
|
|
|
|
from mcp_python.server.types import InitializationOptions
|
|
from mcp_python.shared.session import (
|
|
BaseSession,
|
|
RequestResponder,
|
|
)
|
|
from mcp_python.types import (
|
|
LATEST_PROTOCOL_VERSION,
|
|
ClientNotification,
|
|
ClientRequest,
|
|
CreateMessageResult,
|
|
EmptyResult,
|
|
Implementation,
|
|
IncludeContext,
|
|
InitializedNotification,
|
|
InitializeRequest,
|
|
InitializeResult,
|
|
JSONRPCMessage,
|
|
LoggingLevel,
|
|
SamplingMessage,
|
|
ServerNotification,
|
|
ServerRequest,
|
|
ServerResult,
|
|
)
|
|
|
|
|
|
class InitializationState(Enum):
    """Stages of the initialize handshake as tracked by a server session."""

    # Nothing has happened yet: no ``initialize`` request has been handled.
    NotInitialized = 1
    # The ``initialize`` request was answered; the client's ``initialized``
    # notification has not arrived yet.
    Initializing = 2
    # The ``initialized`` notification was received; the handshake is done.
    Initialized = 3
|
|
|
|
|
|
class ServerSession(
|
|
BaseSession[
|
|
ServerRequest,
|
|
ServerNotification,
|
|
ServerResult,
|
|
ClientRequest,
|
|
ClientNotification,
|
|
]
|
|
):
|
|
_initialized: InitializationState = InitializationState.NotInitialized
|
|
|
|
def __init__(
|
|
self,
|
|
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
|
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
|
init_options: InitializationOptions,
|
|
) -> None:
|
|
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
|
self._initialization_state = InitializationState.NotInitialized
|
|
self._init_options = init_options
|
|
|
|
async def _received_request(
|
|
self, responder: RequestResponder[ClientRequest, ServerResult]
|
|
):
|
|
match responder.request.root:
|
|
case InitializeRequest():
|
|
self._initialization_state = InitializationState.Initializing
|
|
await responder.respond(
|
|
ServerResult(
|
|
InitializeResult(
|
|
protocolVersion=LATEST_PROTOCOL_VERSION,
|
|
capabilities=self._init_options.capabilities,
|
|
serverInfo=Implementation(
|
|
name=self._init_options.server_name,
|
|
version=self._init_options.server_version,
|
|
),
|
|
)
|
|
)
|
|
)
|
|
case _:
|
|
if self._initialization_state != InitializationState.Initialized:
|
|
raise RuntimeError(
|
|
"Received request before initialization was complete"
|
|
)
|
|
|
|
async def _received_notification(self, notification: ClientNotification) -> None:
|
|
# Need this to avoid ASYNC910
|
|
await anyio.lowlevel.checkpoint()
|
|
match notification.root:
|
|
case InitializedNotification():
|
|
self._initialization_state = InitializationState.Initialized
|
|
case _:
|
|
if self._initialization_state != InitializationState.Initialized:
|
|
raise RuntimeError(
|
|
"Received notification before initialization was complete"
|
|
)
|
|
|
|
async def send_log_message(
|
|
self, level: LoggingLevel, data: Any, logger: str | None = None
|
|
) -> None:
|
|
"""Send a log message notification."""
|
|
from mcp_python.types import (
|
|
LoggingMessageNotification,
|
|
LoggingMessageNotificationParams,
|
|
)
|
|
|
|
await self.send_notification(
|
|
ServerNotification(
|
|
LoggingMessageNotification(
|
|
method="notifications/message",
|
|
params=LoggingMessageNotificationParams(
|
|
level=level,
|
|
data=data,
|
|
logger=logger,
|
|
),
|
|
)
|
|
)
|
|
)
|
|
|
|
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
|
"""Send a resource updated notification."""
|
|
from mcp_python.types import (
|
|
ResourceUpdatedNotification,
|
|
ResourceUpdatedNotificationParams,
|
|
)
|
|
|
|
await self.send_notification(
|
|
ServerNotification(
|
|
ResourceUpdatedNotification(
|
|
method="notifications/resources/updated",
|
|
params=ResourceUpdatedNotificationParams(uri=uri),
|
|
)
|
|
)
|
|
)
|
|
|
|
async def request_create_message(
|
|
self,
|
|
messages: list[SamplingMessage],
|
|
*,
|
|
max_tokens: int,
|
|
system_prompt: str | None = None,
|
|
include_context: IncludeContext | None = None,
|
|
temperature: float | None = None,
|
|
stop_sequences: list[str] | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> CreateMessageResult:
|
|
"""Send a sampling/create_message request."""
|
|
from mcp_python.types import (
|
|
CreateMessageRequest,
|
|
CreateMessageRequestParams,
|
|
)
|
|
|
|
return await self.send_request(
|
|
ServerRequest(
|
|
CreateMessageRequest(
|
|
method="sampling/createMessage",
|
|
params=CreateMessageRequestParams(
|
|
messages=messages,
|
|
systemPrompt=system_prompt,
|
|
includeContext=include_context,
|
|
temperature=temperature,
|
|
maxTokens=max_tokens,
|
|
stopSequences=stop_sequences,
|
|
metadata=metadata,
|
|
),
|
|
)
|
|
),
|
|
CreateMessageResult,
|
|
)
|
|
|
|
async def send_ping(self) -> EmptyResult:
|
|
"""Send a ping request."""
|
|
from mcp_python.types import PingRequest
|
|
|
|
return await self.send_request(
|
|
ServerRequest(
|
|
PingRequest(
|
|
method="ping",
|
|
)
|
|
),
|
|
EmptyResult,
|
|
)
|
|
|
|
async def send_progress_notification(
|
|
self, progress_token: str | int, progress: float, total: float | None = None
|
|
) -> None:
|
|
"""Send a progress notification."""
|
|
from mcp_python.types import ProgressNotification, ProgressNotificationParams
|
|
|
|
await self.send_notification(
|
|
ServerNotification(
|
|
ProgressNotification(
|
|
method="notifications/progress",
|
|
params=ProgressNotificationParams(
|
|
progressToken=progress_token,
|
|
progress=progress,
|
|
total=total,
|
|
),
|
|
)
|
|
)
|
|
)
|