mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Allow passing initialization options to a session
We need a way for servers to pass initialization options to the session. This is the beginning of this.
This commit is contained in:
@@ -32,6 +32,7 @@ from mcp_python.types import (
|
|||||||
ReadResourceResult,
|
ReadResourceResult,
|
||||||
Resource,
|
Resource,
|
||||||
ResourceReference,
|
ResourceReference,
|
||||||
|
ServerCapabilities,
|
||||||
ServerResult,
|
ServerResult,
|
||||||
SetLevelRequest,
|
SetLevelRequest,
|
||||||
SubscribeRequest,
|
SubscribeRequest,
|
||||||
@@ -40,7 +41,6 @@ from mcp_python.types import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
|
request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
|
||||||
"request_ctx"
|
"request_ctx"
|
||||||
)
|
)
|
||||||
@@ -53,6 +53,33 @@ class Server:
|
|||||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||||
logger.info(f"Initializing server '{name}'")
|
logger.info(f"Initializing server '{name}'")
|
||||||
|
|
||||||
|
def create_initialization_options(self) -> types.InitializationOptions:
|
||||||
|
"""Create initialization options from this server instance."""
|
||||||
|
def pkg_version(package: str) -> str:
|
||||||
|
try:
|
||||||
|
from importlib.metadata import version
|
||||||
|
return version(package)
|
||||||
|
except Exception:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
return types.InitializationOptions(
|
||||||
|
server_name=self.name,
|
||||||
|
server_version=pkg_version("mcp_python"),
|
||||||
|
capabilities=self.get_capabilities(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_capabilities(self) -> ServerCapabilities:
|
||||||
|
"""Convert existing handlers to a ServerCapabilities object."""
|
||||||
|
def get_capability(req_type: type) -> dict[str, Any] | None:
|
||||||
|
return {} if req_type in self.request_handlers else None
|
||||||
|
|
||||||
|
return ServerCapabilities(
|
||||||
|
prompts=get_capability(ListPromptsRequest),
|
||||||
|
resources=get_capability(ListResourcesRequest),
|
||||||
|
tools=get_capability(ListPromptsRequest),
|
||||||
|
logging=get_capability(SetLevelRequest)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def request_context(self) -> RequestContext:
|
def request_context(self) -> RequestContext:
|
||||||
"""If called outside of a request context, this will raise a LookupError."""
|
"""If called outside of a request context, this will raise a LookupError."""
|
||||||
@@ -280,9 +307,10 @@ class Server:
|
|||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||||
|
initialization_options: types.InitializationOptions
|
||||||
):
|
):
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
async with ServerSession(read_stream, write_stream) as session:
|
async with ServerSession(read_stream, write_stream, initialization_options) as session:
|
||||||
async for message in session.incoming_messages:
|
async for message in session.incoming_messages:
|
||||||
logger.debug(f"Received message: {message}")
|
logger.debug(f"Received message: {message}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
import importlib.metadata
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
from mcp_python.server.session import ServerSession
|
from mcp_python.server.session import ServerSession
|
||||||
|
from mcp_python.server.types import InitializationOptions
|
||||||
from mcp_python.server.stdio import stdio_server
|
from mcp_python.server.stdio import stdio_server
|
||||||
|
from mcp_python.types import ServerCapabilities
|
||||||
|
|
||||||
if not sys.warnoptions:
|
if not sys.warnoptions:
|
||||||
import warnings
|
import warnings
|
||||||
@@ -26,8 +28,9 @@ async def receive_loop(session: ServerSession):
|
|||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
version = importlib.metadata.version("mcp_python")
|
||||||
async with stdio_server() as (read_stream, write_stream):
|
async with stdio_server() as (read_stream, write_stream):
|
||||||
async with ServerSession(read_stream, write_stream) as session, write_stream:
|
async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
|
||||||
await receive_loop(session)
|
await receive_loop(session)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from mcp_python.shared.session import (
|
|||||||
BaseSession,
|
BaseSession,
|
||||||
RequestResponder,
|
RequestResponder,
|
||||||
)
|
)
|
||||||
|
from mcp_python.server.types import InitializationOptions
|
||||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
||||||
from mcp_python.types import (
|
from mcp_python.types import (
|
||||||
ClientNotification,
|
ClientNotification,
|
||||||
@@ -52,9 +53,11 @@ class ServerSession(
|
|||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||||
|
init_options: InitializationOptions
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
||||||
self._initialization_state = InitializationState.NotInitialized
|
self._initialization_state = InitializationState.NotInitialized
|
||||||
|
self._init_options = init_options
|
||||||
|
|
||||||
async def _received_request(
|
async def _received_request(
|
||||||
self, responder: RequestResponder[ClientRequest, ServerResult]
|
self, responder: RequestResponder[ClientRequest, ServerResult]
|
||||||
@@ -66,15 +69,10 @@ class ServerSession(
|
|||||||
ServerResult(
|
ServerResult(
|
||||||
InitializeResult(
|
InitializeResult(
|
||||||
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
|
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
|
||||||
capabilities=ServerCapabilities(
|
capabilities=self._init_options.capabilities,
|
||||||
logging=None,
|
|
||||||
resources=None,
|
|
||||||
tools=None,
|
|
||||||
experimental=None,
|
|
||||||
prompts={},
|
|
||||||
),
|
|
||||||
serverInfo=Implementation(
|
serverInfo=Implementation(
|
||||||
name="mcp_python", version="0.1.0"
|
name=self._init_options.server_name,
|
||||||
|
version=self._init_options.server_version
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ This module provides simpler types to use with the server for managing prompts.
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from mcp_python.types import Role
|
from pydantic import BaseModel
|
||||||
|
from mcp_python.types import Role, ServerCapabilities
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -25,3 +26,9 @@ class Message:
|
|||||||
class PromptResponse:
|
class PromptResponse:
|
||||||
messages: list[Message]
|
messages: list[Message]
|
||||||
desc: str | None = None
|
desc: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InitializationOptions(BaseModel):
|
||||||
|
server_name: str
|
||||||
|
server_version: str
|
||||||
|
capabilities: ServerCapabilities
|
||||||
|
|||||||
@@ -35,3 +35,8 @@ target-version = "py38"
|
|||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"__init__.py" = ["F401"]
|
"__init__.py" = ["F401"]
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
dev-dependencies = [
|
||||||
|
"trio>=0.26.2",
|
||||||
|
]
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ import anyio
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mcp_python.client.session import ClientSession
|
from mcp_python.client.session import ClientSession
|
||||||
|
from mcp_python.server import Server
|
||||||
from mcp_python.server.session import ServerSession
|
from mcp_python.server.session import ServerSession
|
||||||
|
from mcp_python.server.types import InitializationOptions
|
||||||
from mcp_python.types import (
|
from mcp_python.types import (
|
||||||
ClientNotification,
|
ClientNotification,
|
||||||
InitializedNotification,
|
InitializedNotification,
|
||||||
JSONRPCMessage,
|
JSONRPCMessage,
|
||||||
|
ServerCapabilities,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -30,7 +33,7 @@ async def test_server_session_initialize():
|
|||||||
nonlocal received_initialized
|
nonlocal received_initialized
|
||||||
|
|
||||||
async with ServerSession(
|
async with ServerSession(
|
||||||
client_to_server_receive, server_to_client_send
|
client_to_server_receive, server_to_client_send, InitializationOptions(server_name='mcp_python', server_version='0.1.0', capabilities=ServerCapabilities())
|
||||||
) as server_session:
|
) as server_session:
|
||||||
async for message in server_session.incoming_messages:
|
async for message in server_session.incoming_messages:
|
||||||
if isinstance(message, Exception):
|
if isinstance(message, Exception):
|
||||||
@@ -57,3 +60,31 @@ async def test_server_session_initialize():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
assert received_initialized
|
assert received_initialized
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_server_capabilities():
|
||||||
|
server = Server("test")
|
||||||
|
|
||||||
|
# Initially no capabilities
|
||||||
|
caps = server.get_capabilities()
|
||||||
|
assert caps.prompts is None
|
||||||
|
assert caps.resources is None
|
||||||
|
|
||||||
|
# Add a prompts handler
|
||||||
|
@server.list_prompts()
|
||||||
|
async def list_prompts():
|
||||||
|
return []
|
||||||
|
|
||||||
|
caps = server.get_capabilities()
|
||||||
|
assert caps.prompts == {}
|
||||||
|
assert caps.resources is None
|
||||||
|
|
||||||
|
# Add a resources handler
|
||||||
|
@server.list_resources()
|
||||||
|
async def list_resources():
|
||||||
|
return []
|
||||||
|
|
||||||
|
caps = server.get_capabilities()
|
||||||
|
assert caps.prompts == {}
|
||||||
|
assert caps.resources == {}
|
||||||
|
|||||||
Reference in New Issue
Block a user