Merge pull request #16 from modelcontextprotocol/davidsp/init-options

Introduce Initialization options that are passed to ServerSession
This commit is contained in:
David Soria Parra
2024-10-11 16:10:31 +01:00
committed by GitHub
6 changed files with 86 additions and 14 deletions

View File

@@ -32,6 +32,7 @@ from mcp_python.types import (
ReadResourceResult, ReadResourceResult,
Resource, Resource,
ResourceReference, ResourceReference,
ServerCapabilities,
ServerResult, ServerResult,
SetLevelRequest, SetLevelRequest,
SubscribeRequest, SubscribeRequest,
@@ -40,7 +41,6 @@ from mcp_python.types import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar( request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
"request_ctx" "request_ctx"
) )
@@ -53,6 +53,33 @@ class Server:
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
logger.info(f"Initializing server '{name}'") logger.info(f"Initializing server '{name}'")
def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
def pkg_version(package: str) -> str:
try:
from importlib.metadata import version
return version(package)
except Exception:
return "unknown"
return types.InitializationOptions(
server_name=self.name,
server_version=pkg_version("mcp_python"),
capabilities=self.get_capabilities(),
)
def get_capabilities(self) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None
return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest)
)
@property @property
def request_context(self) -> RequestContext: def request_context(self) -> RequestContext:
"""If called outside of a request context, this will raise a LookupError.""" """If called outside of a request context, this will raise a LookupError."""
@@ -280,9 +307,10 @@ class Server:
self, self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions
): ):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
async with ServerSession(read_stream, write_stream) as session: async with ServerSession(read_stream, write_stream, initialization_options) as session:
async for message in session.incoming_messages: async for message in session.incoming_messages:
logger.debug(f"Received message: {message}") logger.debug(f"Received message: {message}")

View File

@@ -1,10 +1,12 @@
import logging import logging
import sys import sys
import importlib.metadata
import anyio import anyio
from mcp_python.server.session import ServerSession from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.server.stdio import stdio_server from mcp_python.server.stdio import stdio_server
from mcp_python.types import ServerCapabilities
if not sys.warnoptions: if not sys.warnoptions:
import warnings import warnings
@@ -26,8 +28,9 @@ async def receive_loop(session: ServerSession):
async def main(): async def main():
version = importlib.metadata.version("mcp_python")
async with stdio_server() as (read_stream, write_stream): async with stdio_server() as (read_stream, write_stream):
async with ServerSession(read_stream, write_stream) as session, write_stream: async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
await receive_loop(session) await receive_loop(session)

View File

@@ -10,6 +10,7 @@ from mcp_python.shared.session import (
BaseSession, BaseSession,
RequestResponder, RequestResponder,
) )
from mcp_python.server.types import InitializationOptions
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
from mcp_python.types import ( from mcp_python.types import (
ClientNotification, ClientNotification,
@@ -52,9 +53,11 @@ class ServerSession(
self, self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[JSONRPCMessage],
init_options: InitializationOptions
) -> None: ) -> None:
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification) super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
self._initialization_state = InitializationState.NotInitialized self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
async def _received_request( async def _received_request(
self, responder: RequestResponder[ClientRequest, ServerResult] self, responder: RequestResponder[ClientRequest, ServerResult]
@@ -66,15 +69,10 @@ class ServerSession(
ServerResult( ServerResult(
InitializeResult( InitializeResult(
protocolVersion=SUPPORTED_PROTOCOL_VERSION, protocolVersion=SUPPORTED_PROTOCOL_VERSION,
capabilities=ServerCapabilities( capabilities=self._init_options.capabilities,
logging=None,
resources=None,
tools=None,
experimental=None,
prompts={},
),
serverInfo=Implementation( serverInfo=Implementation(
name="mcp_python", version="0.1.0" name=self._init_options.server_name,
version=self._init_options.server_version
), ),
) )
) )

View File

@@ -5,7 +5,8 @@ This module provides simpler types to use with the server for managing prompts.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import Literal
from mcp_python.types import Role from pydantic import BaseModel
from mcp_python.types import Role, ServerCapabilities
@dataclass @dataclass
@@ -25,3 +26,9 @@ class Message:
class PromptResponse: class PromptResponse:
messages: list[Message] messages: list[Message]
desc: str | None = None desc: str | None = None
class InitializationOptions(BaseModel):
server_name: str
server_version: str
capabilities: ServerCapabilities

View File

@@ -35,3 +35,8 @@ target-version = "py38"
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"] "__init__.py" = ["F401"]
[tool.uv]
dev-dependencies = [
"trio>=0.26.2",
]

View File

@@ -2,11 +2,14 @@ import anyio
import pytest import pytest
from mcp_python.client.session import ClientSession from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.server.session import ServerSession from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.types import ( from mcp_python.types import (
ClientNotification, ClientNotification,
InitializedNotification, InitializedNotification,
JSONRPCMessage, JSONRPCMessage,
ServerCapabilities,
) )
@@ -30,7 +33,7 @@ async def test_server_session_initialize():
nonlocal received_initialized nonlocal received_initialized
async with ServerSession( async with ServerSession(
client_to_server_receive, server_to_client_send client_to_server_receive, server_to_client_send, InitializationOptions(server_name='mcp_python', server_version='0.1.0', capabilities=ServerCapabilities())
) as server_session: ) as server_session:
async for message in server_session.incoming_messages: async for message in server_session.incoming_messages:
if isinstance(message, Exception): if isinstance(message, Exception):
@@ -57,3 +60,31 @@ async def test_server_session_initialize():
pass pass
assert received_initialized assert received_initialized
@pytest.mark.anyio
async def test_server_capabilities():
server = Server("test")
# Initially no capabilities
caps = server.get_capabilities()
assert caps.prompts is None
assert caps.resources is None
# Add a prompts handler
@server.list_prompts()
async def list_prompts():
return []
caps = server.get_capabilities()
assert caps.prompts == {}
assert caps.resources is None
# Add a resources handler
@server.list_resources()
async def list_resources():
return []
caps = server.get_capabilities()
assert caps.prompts == {}
assert caps.resources == {}