mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Don't re-export types
We will be a bit more low level and expect callees to import mcp.types instead of relying in re-exported types. This makes usage more explicit and avoids potential collisions in mcp.server.
This commit is contained in:
@@ -5,83 +5,54 @@ from pydantic import AnyUrl
|
|||||||
|
|
||||||
from mcp.shared.session import BaseSession
|
from mcp.shared.session import BaseSession
|
||||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||||
from mcp.types import (
|
import mcp.types as types
|
||||||
LATEST_PROTOCOL_VERSION,
|
|
||||||
CallToolResult,
|
|
||||||
ClientCapabilities,
|
|
||||||
ClientNotification,
|
|
||||||
ClientRequest,
|
|
||||||
ClientResult,
|
|
||||||
CompleteResult,
|
|
||||||
EmptyResult,
|
|
||||||
GetPromptResult,
|
|
||||||
Implementation,
|
|
||||||
InitializedNotification,
|
|
||||||
InitializeResult,
|
|
||||||
JSONRPCMessage,
|
|
||||||
ListPromptsResult,
|
|
||||||
ListResourcesResult,
|
|
||||||
ListToolsResult,
|
|
||||||
LoggingLevel,
|
|
||||||
PromptReference,
|
|
||||||
ReadResourceResult,
|
|
||||||
ResourceReference,
|
|
||||||
RootsCapability,
|
|
||||||
ServerNotification,
|
|
||||||
ServerRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClientSession(
|
class ClientSession(
|
||||||
BaseSession[
|
BaseSession[
|
||||||
ClientRequest,
|
types.ClientRequest,
|
||||||
ClientNotification,
|
types.ClientNotification,
|
||||||
ClientResult,
|
types.ClientResult,
|
||||||
ServerRequest,
|
types.ServerRequest,
|
||||||
ServerNotification,
|
types.ServerNotification,
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
read_stream,
|
read_stream,
|
||||||
write_stream,
|
write_stream,
|
||||||
ServerRequest,
|
types.ServerRequest,
|
||||||
ServerNotification,
|
types.ServerNotification,
|
||||||
read_timeout_seconds=read_timeout_seconds,
|
read_timeout_seconds=read_timeout_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def initialize(self) -> InitializeResult:
|
async def initialize(self) -> types.InitializeResult:
|
||||||
from mcp.types import (
|
|
||||||
InitializeRequest,
|
|
||||||
InitializeRequestParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await self.send_request(
|
result = await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
InitializeRequest(
|
types.InitializeRequest(
|
||||||
method="initialize",
|
method="initialize",
|
||||||
params=InitializeRequestParams(
|
params=types.InitializeRequestParams(
|
||||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||||
capabilities=ClientCapabilities(
|
capabilities=types.ClientCapabilities(
|
||||||
sampling=None,
|
sampling=None,
|
||||||
experimental=None,
|
experimental=None,
|
||||||
roots=RootsCapability(
|
roots=types.RootsCapability(
|
||||||
# TODO: Should this be based on whether we
|
# TODO: Should this be based on whether we
|
||||||
# _will_ send notifications, or only whether
|
# _will_ send notifications, or only whether
|
||||||
# they're supported?
|
# they're supported?
|
||||||
listChanged=True
|
listChanged=True
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
clientInfo=Implementation(name="mcp", version="0.1.0"),
|
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
InitializeResult,
|
types.InitializeResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||||
@@ -91,40 +62,33 @@ class ClientSession(
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ClientNotification(
|
types.ClientNotification(
|
||||||
InitializedNotification(method="notifications/initialized")
|
types.InitializedNotification(method="notifications/initialized")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def send_ping(self) -> EmptyResult:
|
async def send_ping(self) -> types.EmptyResult:
|
||||||
"""Send a ping request."""
|
"""Send a ping request."""
|
||||||
from mcp.types import PingRequest
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
PingRequest(
|
types.PingRequest(
|
||||||
method="ping",
|
method="ping",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
EmptyResult,
|
types.EmptyResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_progress_notification(
|
async def send_progress_notification(
|
||||||
self, progress_token: str | int, progress: float, total: float | None = None
|
self, progress_token: str | int, progress: float, total: float | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send a progress notification."""
|
"""Send a progress notification."""
|
||||||
from mcp.types import (
|
|
||||||
ProgressNotification,
|
|
||||||
ProgressNotificationParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ClientNotification(
|
types.ClientNotification(
|
||||||
ProgressNotification(
|
types.ProgressNotification(
|
||||||
method="notifications/progress",
|
method="notifications/progress",
|
||||||
params=ProgressNotificationParams(
|
params=types.ProgressNotificationParams(
|
||||||
progressToken=progress_token,
|
progressToken=progress_token,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
total=total,
|
total=total,
|
||||||
@@ -133,180 +97,137 @@ class ClientSession(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
|
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||||
"""Send a logging/setLevel request."""
|
"""Send a logging/setLevel request."""
|
||||||
from mcp.types import (
|
|
||||||
SetLevelRequest,
|
|
||||||
SetLevelRequestParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
SetLevelRequest(
|
types.SetLevelRequest(
|
||||||
method="logging/setLevel",
|
method="logging/setLevel",
|
||||||
params=SetLevelRequestParams(level=level),
|
params=types.SetLevelRequestParams(level=level),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
EmptyResult,
|
types.EmptyResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_resources(self) -> ListResourcesResult:
|
async def list_resources(self) -> types.ListResourcesResult:
|
||||||
"""Send a resources/list request."""
|
"""Send a resources/list request."""
|
||||||
from mcp.types import (
|
|
||||||
ListResourcesRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
ListResourcesRequest(
|
types.ListResourcesRequest(
|
||||||
method="resources/list",
|
method="resources/list",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
ListResourcesResult,
|
types.ListResourcesResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
|
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||||
"""Send a resources/read request."""
|
"""Send a resources/read request."""
|
||||||
from mcp.types import (
|
|
||||||
ReadResourceRequest,
|
|
||||||
ReadResourceRequestParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
ReadResourceRequest(
|
types.ReadResourceRequest(
|
||||||
method="resources/read",
|
method="resources/read",
|
||||||
params=ReadResourceRequestParams(uri=uri),
|
params=types.ReadResourceRequestParams(uri=uri),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
ReadResourceResult,
|
types.ReadResourceResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||||
"""Send a resources/subscribe request."""
|
"""Send a resources/subscribe request."""
|
||||||
from mcp.types import (
|
|
||||||
SubscribeRequest,
|
|
||||||
SubscribeRequestParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
SubscribeRequest(
|
types.SubscribeRequest(
|
||||||
method="resources/subscribe",
|
method="resources/subscribe",
|
||||||
params=SubscribeRequestParams(uri=uri),
|
params=types.SubscribeRequestParams(uri=uri),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
EmptyResult,
|
types.EmptyResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||||
"""Send a resources/unsubscribe request."""
|
"""Send a resources/unsubscribe request."""
|
||||||
from mcp.types import (
|
|
||||||
UnsubscribeRequest,
|
|
||||||
UnsubscribeRequestParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
UnsubscribeRequest(
|
types.UnsubscribeRequest(
|
||||||
method="resources/unsubscribe",
|
method="resources/unsubscribe",
|
||||||
params=UnsubscribeRequestParams(uri=uri),
|
params=types.UnsubscribeRequestParams(uri=uri),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
EmptyResult,
|
types.EmptyResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def call_tool(
|
async def call_tool(
|
||||||
self, name: str, arguments: dict | None = None
|
self, name: str, arguments: dict | None = None
|
||||||
) -> CallToolResult:
|
) -> types.CallToolResult:
|
||||||
"""Send a tools/call request."""
|
"""Send a tools/call request."""
|
||||||
from mcp.types import (
|
|
||||||
CallToolRequest,
|
|
||||||
CallToolRequestParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
CallToolRequest(
|
types.CallToolRequest(
|
||||||
method="tools/call",
|
method="tools/call",
|
||||||
params=CallToolRequestParams(name=name, arguments=arguments),
|
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
CallToolResult,
|
types.CallToolResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_prompts(self) -> ListPromptsResult:
|
async def list_prompts(self) -> types.ListPromptsResult:
|
||||||
"""Send a prompts/list request."""
|
"""Send a prompts/list request."""
|
||||||
from mcp.types import ListPromptsRequest
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
ListPromptsRequest(
|
types.ListPromptsRequest(
|
||||||
method="prompts/list",
|
method="prompts/list",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
ListPromptsResult,
|
types.ListPromptsResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_prompt(
|
async def get_prompt(
|
||||||
self, name: str, arguments: dict[str, str] | None = None
|
self, name: str, arguments: dict[str, str] | None = None
|
||||||
) -> GetPromptResult:
|
) -> types.GetPromptResult:
|
||||||
"""Send a prompts/get request."""
|
"""Send a prompts/get request."""
|
||||||
from mcp.types import GetPromptRequest, GetPromptRequestParams
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
GetPromptRequest(
|
types.GetPromptRequest(
|
||||||
method="prompts/get",
|
method="prompts/get",
|
||||||
params=GetPromptRequestParams(name=name, arguments=arguments),
|
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
GetPromptResult,
|
types.GetPromptResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def complete(
|
async def complete(
|
||||||
self, ref: ResourceReference | PromptReference, argument: dict
|
self, ref: types.ResourceReference | types.PromptReference, argument: dict
|
||||||
) -> CompleteResult:
|
) -> types.CompleteResult:
|
||||||
"""Send a completion/complete request."""
|
"""Send a completion/complete request."""
|
||||||
from mcp.types import (
|
|
||||||
CompleteRequest,
|
|
||||||
CompleteRequestParams,
|
|
||||||
CompletionArgument,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
CompleteRequest(
|
types.CompleteRequest(
|
||||||
method="completion/complete",
|
method="completion/complete",
|
||||||
params=CompleteRequestParams(
|
params=types.CompleteRequestParams(
|
||||||
ref=ref,
|
ref=ref,
|
||||||
argument=CompletionArgument(**argument),
|
argument=types.CompletionArgument(**argument),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
CompleteResult,
|
types.CompleteResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_tools(self) -> ListToolsResult:
|
async def list_tools(self) -> types.ListToolsResult:
|
||||||
"""Send a tools/list request."""
|
"""Send a tools/list request."""
|
||||||
from mcp.types import ListToolsRequest
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ClientRequest(
|
types.ClientRequest(
|
||||||
ListToolsRequest(
|
types.ListToolsRequest(
|
||||||
method="tools/list",
|
method="tools/list",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
ListToolsResult,
|
types.ListToolsResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_roots_list_changed(self) -> None:
|
async def send_roots_list_changed(self) -> None:
|
||||||
"""Send a roots/list_changed notification."""
|
"""Send a roots/list_changed notification."""
|
||||||
from mcp.types import RootsListChangedNotification
|
|
||||||
|
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ClientNotification(
|
types.ClientNotification(
|
||||||
RootsListChangedNotification(
|
types.RootsListChangedNotification(
|
||||||
method="notifications/roots/list_changed",
|
method="notifications/roots/list_changed",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from anyio.abc import TaskStatus
|
|||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from httpx_sse import aconnect_sse
|
from httpx_sse import aconnect_sse
|
||||||
|
|
||||||
from mcp.types import JSONRPCMessage
|
import mcp.types as types
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -31,11 +31,11 @@ async def sse_client(
|
|||||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||||
"""
|
"""
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -85,7 +85,7 @@ async def sse_client(
|
|||||||
case "message":
|
case "message":
|
||||||
try:
|
try:
|
||||||
message = (
|
message = (
|
||||||
JSONRPCMessage.model_validate_json(
|
types.JSONRPCMessage.model_validate_json(
|
||||||
sse.data
|
sse.data
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
|
|||||||
from anyio.streams.text import TextReceiveStream
|
from anyio.streams.text import TextReceiveStream
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from mcp.types import JSONRPCMessage
|
import mcp.types as types
|
||||||
|
|
||||||
# Environment variables to inherit by default
|
# Environment variables to inherit by default
|
||||||
DEFAULT_INHERITED_ENV_VARS = (
|
DEFAULT_INHERITED_ENV_VARS = (
|
||||||
@@ -72,11 +72,11 @@ async def stdio_client(server: StdioServerParameters):
|
|||||||
Client transport for stdio: this will connect to a server by spawning a
|
Client transport for stdio: this will connect to a server by spawning a
|
||||||
process and communicating with it over stdin/stdout.
|
process and communicating with it over stdin/stdout.
|
||||||
"""
|
"""
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -99,7 +99,7 @@ async def stdio_client(server: StdioServerParameters):
|
|||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
try:
|
try:
|
||||||
message = JSONRPCMessage.model_validate_json(line)
|
message = types.JSONRPCMessage.model_validate_json(line)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -12,48 +12,7 @@ from mcp.server.session import ServerSession
|
|||||||
from mcp.server.stdio import stdio_server as stdio_server
|
from mcp.server.stdio import stdio_server as stdio_server
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.shared.session import RequestResponder
|
from mcp.shared.session import RequestResponder
|
||||||
from mcp.types import (
|
import mcp.types as types
|
||||||
METHOD_NOT_FOUND,
|
|
||||||
CallToolRequest,
|
|
||||||
GetPromptResult,
|
|
||||||
GetPromptRequest,
|
|
||||||
GetPromptResult,
|
|
||||||
ImageContent,
|
|
||||||
ClientNotification,
|
|
||||||
ClientRequest,
|
|
||||||
CompleteRequest,
|
|
||||||
EmbeddedResource,
|
|
||||||
EmptyResult,
|
|
||||||
ErrorData,
|
|
||||||
JSONRPCMessage,
|
|
||||||
ListPromptsRequest,
|
|
||||||
ListPromptsResult,
|
|
||||||
ListResourcesRequest,
|
|
||||||
ListResourcesResult,
|
|
||||||
ListToolsRequest,
|
|
||||||
ListToolsResult,
|
|
||||||
LoggingCapability,
|
|
||||||
LoggingLevel,
|
|
||||||
PingRequest,
|
|
||||||
ProgressNotification,
|
|
||||||
Prompt,
|
|
||||||
PromptMessage,
|
|
||||||
PromptReference,
|
|
||||||
PromptsCapability,
|
|
||||||
ReadResourceRequest,
|
|
||||||
ReadResourceResult,
|
|
||||||
Resource,
|
|
||||||
ResourceReference,
|
|
||||||
ResourcesCapability,
|
|
||||||
ServerCapabilities,
|
|
||||||
ServerResult,
|
|
||||||
SetLevelRequest,
|
|
||||||
SubscribeRequest,
|
|
||||||
TextContent,
|
|
||||||
Tool,
|
|
||||||
ToolsCapability,
|
|
||||||
UnsubscribeRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -77,8 +36,8 @@ class NotificationOptions:
|
|||||||
class Server:
|
class Server:
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
|
self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
|
||||||
PingRequest: _ping_handler,
|
types.PingRequest: _ping_handler,
|
||||||
}
|
}
|
||||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||||
self.notification_options = NotificationOptions()
|
self.notification_options = NotificationOptions()
|
||||||
@@ -116,7 +75,7 @@ class Server:
|
|||||||
self,
|
self,
|
||||||
notification_options: NotificationOptions,
|
notification_options: NotificationOptions,
|
||||||
experimental_capabilities: dict[str, dict[str, Any]],
|
experimental_capabilities: dict[str, dict[str, Any]],
|
||||||
) -> ServerCapabilities:
|
) -> types.ServerCapabilities:
|
||||||
"""Convert existing handlers to a ServerCapabilities object."""
|
"""Convert existing handlers to a ServerCapabilities object."""
|
||||||
prompts_capability = None
|
prompts_capability = None
|
||||||
resources_capability = None
|
resources_capability = None
|
||||||
@@ -124,28 +83,28 @@ class Server:
|
|||||||
logging_capability = None
|
logging_capability = None
|
||||||
|
|
||||||
# Set prompt capabilities if handler exists
|
# Set prompt capabilities if handler exists
|
||||||
if ListPromptsRequest in self.request_handlers:
|
if types.ListPromptsRequest in self.request_handlers:
|
||||||
prompts_capability = PromptsCapability(
|
prompts_capability = types.PromptsCapability(
|
||||||
listChanged=notification_options.prompts_changed
|
listChanged=notification_options.prompts_changed
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set resource capabilities if handler exists
|
# Set resource capabilities if handler exists
|
||||||
if ListResourcesRequest in self.request_handlers:
|
if types.ListResourcesRequest in self.request_handlers:
|
||||||
resources_capability = ResourcesCapability(
|
resources_capability = types.ResourcesCapability(
|
||||||
subscribe=False, listChanged=notification_options.resources_changed
|
subscribe=False, listChanged=notification_options.resources_changed
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set tool capabilities if handler exists
|
# Set tool capabilities if handler exists
|
||||||
if ListToolsRequest in self.request_handlers:
|
if types.ListToolsRequest in self.request_handlers:
|
||||||
tools_capability = ToolsCapability(
|
tools_capability = types.ToolsCapability(
|
||||||
listChanged=notification_options.tools_changed
|
listChanged=notification_options.tools_changed
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set logging capabilities if handler exists
|
# Set logging capabilities if handler exists
|
||||||
if SetLevelRequest in self.request_handlers:
|
if types.SetLevelRequest in self.request_handlers:
|
||||||
logging_capability = LoggingCapability()
|
logging_capability = types.LoggingCapability()
|
||||||
|
|
||||||
return ServerCapabilities(
|
return types.ServerCapabilities(
|
||||||
prompts=prompts_capability,
|
prompts=prompts_capability,
|
||||||
resources=resources_capability,
|
resources=resources_capability,
|
||||||
tools=tools_capability,
|
tools=tools_capability,
|
||||||
@@ -159,14 +118,14 @@ class Server:
|
|||||||
return request_ctx.get()
|
return request_ctx.get()
|
||||||
|
|
||||||
def list_prompts(self):
|
def list_prompts(self):
|
||||||
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
|
def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]):
|
||||||
logger.debug("Registering handler for PromptListRequest")
|
logger.debug("Registering handler for PromptListRequest")
|
||||||
|
|
||||||
async def handler(_: Any):
|
async def handler(_: Any):
|
||||||
prompts = await func()
|
prompts = await func()
|
||||||
return ServerResult(ListPromptsResult(prompts=prompts))
|
return types.ServerResult(types.ListPromptsResult(prompts=prompts))
|
||||||
|
|
||||||
self.request_handlers[ListPromptsRequest] = handler
|
self.request_handlers[types.ListPromptsRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -174,47 +133,42 @@ class Server:
|
|||||||
def get_prompt(self):
|
def get_prompt(self):
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[
|
func: Callable[
|
||||||
[str, dict[str, str] | None], Awaitable[GetPromptResult]
|
[str, dict[str, str] | None], Awaitable[types.GetPromptResult]
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
logger.debug("Registering handler for GetPromptRequest")
|
logger.debug("Registering handler for GetPromptRequest")
|
||||||
|
|
||||||
async def handler(req: GetPromptRequest):
|
async def handler(req: types.GetPromptRequest):
|
||||||
prompt_get = await func(req.params.name, req.params.arguments)
|
prompt_get = await func(req.params.name, req.params.arguments)
|
||||||
return ServerResult(prompt_get)
|
return types.ServerResult(prompt_get)
|
||||||
|
|
||||||
self.request_handlers[GetPromptRequest] = handler
|
self.request_handlers[types.GetPromptRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def list_resources(self):
|
def list_resources(self):
|
||||||
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
|
def decorator(func: Callable[[], Awaitable[list[types.Resource]]]):
|
||||||
logger.debug("Registering handler for ListResourcesRequest")
|
logger.debug("Registering handler for ListResourcesRequest")
|
||||||
|
|
||||||
async def handler(_: Any):
|
async def handler(_: Any):
|
||||||
resources = await func()
|
resources = await func()
|
||||||
return ServerResult(ListResourcesResult(resources=resources))
|
return types.ServerResult(types.ListResourcesResult(resources=resources))
|
||||||
|
|
||||||
self.request_handlers[ListResourcesRequest] = handler
|
self.request_handlers[types.ListResourcesRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def read_resource(self):
|
def read_resource(self):
|
||||||
from mcp.types import (
|
|
||||||
BlobResourceContents,
|
|
||||||
TextResourceContents,
|
|
||||||
)
|
|
||||||
|
|
||||||
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
|
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
|
||||||
logger.debug("Registering handler for ReadResourceRequest")
|
logger.debug("Registering handler for ReadResourceRequest")
|
||||||
|
|
||||||
async def handler(req: ReadResourceRequest):
|
async def handler(req: types.ReadResourceRequest):
|
||||||
result = await func(req.params.uri)
|
result = await func(req.params.uri)
|
||||||
match result:
|
match result:
|
||||||
case str(s):
|
case str(s):
|
||||||
content = TextResourceContents(
|
content = types.TextResourceContents(
|
||||||
uri=req.params.uri,
|
uri=req.params.uri,
|
||||||
text=s,
|
text=s,
|
||||||
mimeType="text/plain",
|
mimeType="text/plain",
|
||||||
@@ -222,130 +176,117 @@ class Server:
|
|||||||
case bytes(b):
|
case bytes(b):
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
content = BlobResourceContents(
|
content = types.BlobResourceContents(
|
||||||
uri=req.params.uri,
|
uri=req.params.uri,
|
||||||
blob=base64.urlsafe_b64encode(b).decode(),
|
blob=base64.urlsafe_b64encode(b).decode(),
|
||||||
mimeType="application/octet-stream",
|
mimeType="application/octet-stream",
|
||||||
)
|
)
|
||||||
|
|
||||||
return ServerResult(
|
return types.ServerResult(
|
||||||
ReadResourceResult(
|
types.ReadResourceResult(
|
||||||
contents=[content],
|
contents=[content],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.request_handlers[ReadResourceRequest] = handler
|
self.request_handlers[types.ReadResourceRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def set_logging_level(self):
|
def set_logging_level(self):
|
||||||
from mcp.types import EmptyResult
|
def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
|
||||||
|
|
||||||
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
|
|
||||||
logger.debug("Registering handler for SetLevelRequest")
|
logger.debug("Registering handler for SetLevelRequest")
|
||||||
|
|
||||||
async def handler(req: SetLevelRequest):
|
async def handler(req: types.SetLevelRequest):
|
||||||
await func(req.params.level)
|
await func(req.params.level)
|
||||||
return ServerResult(EmptyResult())
|
return types.ServerResult(types.EmptyResult())
|
||||||
|
|
||||||
self.request_handlers[SetLevelRequest] = handler
|
self.request_handlers[types.SetLevelRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def subscribe_resource(self):
|
def subscribe_resource(self):
|
||||||
from mcp.types import EmptyResult
|
|
||||||
|
|
||||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||||
logger.debug("Registering handler for SubscribeRequest")
|
logger.debug("Registering handler for SubscribeRequest")
|
||||||
|
|
||||||
async def handler(req: SubscribeRequest):
|
async def handler(req: types.SubscribeRequest):
|
||||||
await func(req.params.uri)
|
await func(req.params.uri)
|
||||||
return ServerResult(EmptyResult())
|
return types.ServerResult(types.EmptyResult())
|
||||||
|
|
||||||
self.request_handlers[SubscribeRequest] = handler
|
self.request_handlers[types.SubscribeRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def unsubscribe_resource(self):
|
def unsubscribe_resource(self):
|
||||||
from mcp.types import EmptyResult
|
|
||||||
|
|
||||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||||
logger.debug("Registering handler for UnsubscribeRequest")
|
logger.debug("Registering handler for UnsubscribeRequest")
|
||||||
|
|
||||||
async def handler(req: UnsubscribeRequest):
|
async def handler(req: types.UnsubscribeRequest):
|
||||||
await func(req.params.uri)
|
await func(req.params.uri)
|
||||||
return ServerResult(EmptyResult())
|
return types.ServerResult(types.EmptyResult())
|
||||||
|
|
||||||
self.request_handlers[UnsubscribeRequest] = handler
|
self.request_handlers[types.UnsubscribeRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def list_tools(self):
|
def list_tools(self):
|
||||||
def decorator(func: Callable[[], Awaitable[list[Tool]]]):
|
def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
|
||||||
logger.debug("Registering handler for ListToolsRequest")
|
logger.debug("Registering handler for ListToolsRequest")
|
||||||
|
|
||||||
async def handler(_: Any):
|
async def handler(_: Any):
|
||||||
tools = await func()
|
tools = await func()
|
||||||
return ServerResult(ListToolsResult(tools=tools))
|
return types.ServerResult(types.ListToolsResult(tools=tools))
|
||||||
|
|
||||||
self.request_handlers[ListToolsRequest] = handler
|
self.request_handlers[types.ListToolsRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def call_tool(self):
|
def call_tool(self):
|
||||||
from mcp.types import (
|
|
||||||
CallToolResult,
|
|
||||||
EmbeddedResource,
|
|
||||||
ImageContent,
|
|
||||||
TextContent,
|
|
||||||
)
|
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[
|
func: Callable[
|
||||||
...,
|
...,
|
||||||
Awaitable[Sequence[TextContent | ImageContent | EmbeddedResource]],
|
Awaitable[Sequence[types.TextContent | types.ImageContent | types.EmbeddedResource]],
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
logger.debug("Registering handler for CallToolRequest")
|
logger.debug("Registering handler for CallToolRequest")
|
||||||
|
|
||||||
async def handler(req: CallToolRequest):
|
async def handler(req: types.CallToolRequest):
|
||||||
try:
|
try:
|
||||||
results = await func(req.params.name, (req.params.arguments or {}))
|
results = await func(req.params.name, (req.params.arguments or {}))
|
||||||
content = []
|
content = []
|
||||||
for result in results:
|
for result in results:
|
||||||
match result:
|
match result:
|
||||||
case str() as text:
|
case str() as text:
|
||||||
content.append(TextContent(type="text", text=text))
|
content.append(types.TextContent(type="text", text=text))
|
||||||
case ImageContent() as img:
|
case types.ImageContent() as img:
|
||||||
content.append(
|
content.append(
|
||||||
ImageContent(
|
types.ImageContent(
|
||||||
type="image",
|
type="image",
|
||||||
data=img.data,
|
data=img.data,
|
||||||
mimeType=img.mimeType,
|
mimeType=img.mimeType,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
case EmbeddedResource() as resource:
|
case types.EmbeddedResource() as resource:
|
||||||
content.append(
|
content.append(
|
||||||
EmbeddedResource(
|
types.EmbeddedResource(
|
||||||
type="resource", resource=resource.resource
|
type="resource", resource=resource.resource
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return ServerResult(CallToolResult(content=content, isError=False))
|
return types.ServerResult(types.CallToolResult(content=content, isError=False))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ServerResult(
|
return types.ServerResult(
|
||||||
CallToolResult(
|
types.CallToolResult(
|
||||||
content=[TextContent(type="text", text=str(e))],
|
content=[types.TextContent(type="text", text=str(e))],
|
||||||
isError=True,
|
isError=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.request_handlers[CallToolRequest] = handler
|
self.request_handlers[types.CallToolRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -356,47 +297,46 @@ class Server:
|
|||||||
):
|
):
|
||||||
logger.debug("Registering handler for ProgressNotification")
|
logger.debug("Registering handler for ProgressNotification")
|
||||||
|
|
||||||
async def handler(req: ProgressNotification):
|
async def handler(req: types.ProgressNotification):
|
||||||
await func(
|
await func(
|
||||||
req.params.progressToken, req.params.progress, req.params.total
|
req.params.progressToken, req.params.progress, req.params.total
|
||||||
)
|
)
|
||||||
|
|
||||||
self.notification_handlers[ProgressNotification] = handler
|
self.notification_handlers[types.ProgressNotification] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def completion(self):
|
def completion(self):
|
||||||
"""Provides completions for prompts and resource templates"""
|
"""Provides completions for prompts and resource templates"""
|
||||||
from mcp.types import CompleteResult, Completion, CompletionArgument
|
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[
|
func: Callable[
|
||||||
[PromptReference | ResourceReference, CompletionArgument],
|
[types.PromptReference | types.ResourceReference, types.CompletionArgument],
|
||||||
Awaitable[Completion | None],
|
Awaitable[types.Completion | None],
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
logger.debug("Registering handler for CompleteRequest")
|
logger.debug("Registering handler for CompleteRequest")
|
||||||
|
|
||||||
async def handler(req: CompleteRequest):
|
async def handler(req: types.CompleteRequest):
|
||||||
completion = await func(req.params.ref, req.params.argument)
|
completion = await func(req.params.ref, req.params.argument)
|
||||||
return ServerResult(
|
return types.ServerResult(
|
||||||
CompleteResult(
|
types.CompleteResult(
|
||||||
completion=completion
|
completion=completion
|
||||||
if completion is not None
|
if completion is not None
|
||||||
else Completion(values=[], total=None, hasMore=None),
|
else types.Completion(values=[], total=None, hasMore=None),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.request_handlers[CompleteRequest] = handler
|
self.request_handlers[types.CompleteRequest] = handler
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||||
initialization_options: InitializationOptions,
|
initialization_options: InitializationOptions,
|
||||||
# When True, exceptions are returned as messages to the client.
|
# When True, exceptions are returned as messages to the client.
|
||||||
# When False, exceptions are raised, which will cause the server to shut down
|
# When False, exceptions are raised, which will cause the server to shut down
|
||||||
@@ -412,7 +352,7 @@ class Server:
|
|||||||
logger.debug(f"Received message: {message}")
|
logger.debug(f"Received message: {message}")
|
||||||
|
|
||||||
match message:
|
match message:
|
||||||
case RequestResponder(request=ClientRequest(root=req)):
|
case RequestResponder(request=types.ClientRequest(root=req)):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Processing request of type {type(req).__name__}"
|
f"Processing request of type {type(req).__name__}"
|
||||||
)
|
)
|
||||||
@@ -437,7 +377,7 @@ class Server:
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
if raise_exceptions:
|
if raise_exceptions:
|
||||||
raise err
|
raise err
|
||||||
response = ErrorData(
|
response = types.ErrorData(
|
||||||
code=0, message=str(err), data=None
|
code=0, message=str(err), data=None
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
@@ -448,14 +388,14 @@ class Server:
|
|||||||
await message.respond(response)
|
await message.respond(response)
|
||||||
else:
|
else:
|
||||||
await message.respond(
|
await message.respond(
|
||||||
ErrorData(
|
types.ErrorData(
|
||||||
code=METHOD_NOT_FOUND,
|
code=types.METHOD_NOT_FOUND,
|
||||||
message="Method not found",
|
message="Method not found",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Response sent")
|
logger.debug("Response sent")
|
||||||
case ClientNotification(root=notify):
|
case types.ClientNotification(root=notify):
|
||||||
if type(notify) in self.notification_handlers:
|
if type(notify) in self.notification_handlers:
|
||||||
assert type(notify) in self.notification_handlers
|
assert type(notify) in self.notification_handlers
|
||||||
|
|
||||||
@@ -479,5 +419,5 @@ class Server:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _ping_handler(request: PingRequest) -> ServerResult:
|
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
|
||||||
return ServerResult(EmptyResult())
|
return types.ServerResult(types.EmptyResult())
|
||||||
|
|||||||
@@ -11,29 +11,7 @@ from mcp.shared.session import (
|
|||||||
BaseSession,
|
BaseSession,
|
||||||
RequestResponder,
|
RequestResponder,
|
||||||
)
|
)
|
||||||
from mcp.types import (
|
import mcp.types as types
|
||||||
LATEST_PROTOCOL_VERSION,
|
|
||||||
ClientNotification,
|
|
||||||
ClientRequest,
|
|
||||||
CreateMessageResult,
|
|
||||||
EmptyResult,
|
|
||||||
Implementation,
|
|
||||||
IncludeContext,
|
|
||||||
InitializedNotification,
|
|
||||||
InitializeRequest,
|
|
||||||
InitializeResult,
|
|
||||||
JSONRPCMessage,
|
|
||||||
ListRootsResult,
|
|
||||||
LoggingLevel,
|
|
||||||
ModelPreferences,
|
|
||||||
PromptListChangedNotification,
|
|
||||||
ResourceListChangedNotification,
|
|
||||||
SamplingMessage,
|
|
||||||
ServerNotification,
|
|
||||||
ServerRequest,
|
|
||||||
ServerResult,
|
|
||||||
ToolListChangedNotification,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class InitializationState(Enum):
|
class InitializationState(Enum):
|
||||||
@@ -44,37 +22,37 @@ class InitializationState(Enum):
|
|||||||
|
|
||||||
class ServerSession(
|
class ServerSession(
|
||||||
BaseSession[
|
BaseSession[
|
||||||
ServerRequest,
|
types.ServerRequest,
|
||||||
ServerNotification,
|
types.ServerNotification,
|
||||||
ServerResult,
|
types.ServerResult,
|
||||||
ClientRequest,
|
types.ClientRequest,
|
||||||
ClientNotification,
|
types.ClientNotification,
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
_initialized: InitializationState = InitializationState.NotInitialized
|
_initialized: InitializationState = InitializationState.NotInitialized
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||||
init_options: InitializationOptions,
|
init_options: InitializationOptions,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
|
||||||
self._initialization_state = InitializationState.NotInitialized
|
self._initialization_state = InitializationState.NotInitialized
|
||||||
self._init_options = init_options
|
self._init_options = init_options
|
||||||
|
|
||||||
async def _received_request(
|
async def _received_request(
|
||||||
self, responder: RequestResponder[ClientRequest, ServerResult]
|
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||||
):
|
):
|
||||||
match responder.request.root:
|
match responder.request.root:
|
||||||
case InitializeRequest():
|
case types.InitializeRequest():
|
||||||
self._initialization_state = InitializationState.Initializing
|
self._initialization_state = InitializationState.Initializing
|
||||||
await responder.respond(
|
await responder.respond(
|
||||||
ServerResult(
|
types.ServerResult(
|
||||||
InitializeResult(
|
types.InitializeResult(
|
||||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||||
capabilities=self._init_options.capabilities,
|
capabilities=self._init_options.capabilities,
|
||||||
serverInfo=Implementation(
|
serverInfo=types.Implementation(
|
||||||
name=self._init_options.server_name,
|
name=self._init_options.server_name,
|
||||||
version=self._init_options.server_version,
|
version=self._init_options.server_version,
|
||||||
),
|
),
|
||||||
@@ -87,11 +65,11 @@ class ServerSession(
|
|||||||
"Received request before initialization was complete"
|
"Received request before initialization was complete"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _received_notification(self, notification: ClientNotification) -> None:
|
async def _received_notification(self, notification: types.ClientNotification) -> None:
|
||||||
# Need this to avoid ASYNC910
|
# Need this to avoid ASYNC910
|
||||||
await anyio.lowlevel.checkpoint()
|
await anyio.lowlevel.checkpoint()
|
||||||
match notification.root:
|
match notification.root:
|
||||||
case InitializedNotification():
|
case types.InitializedNotification():
|
||||||
self._initialization_state = InitializationState.Initialized
|
self._initialization_state = InitializationState.Initialized
|
||||||
case _:
|
case _:
|
||||||
if self._initialization_state != InitializationState.Initialized:
|
if self._initialization_state != InitializationState.Initialized:
|
||||||
@@ -100,19 +78,14 @@ class ServerSession(
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def send_log_message(
|
async def send_log_message(
|
||||||
self, level: LoggingLevel, data: Any, logger: str | None = None
|
self, level: types.LoggingLevel, data: Any, logger: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send a log message notification."""
|
"""Send a log message notification."""
|
||||||
from mcp.types import (
|
|
||||||
LoggingMessageNotification,
|
|
||||||
LoggingMessageNotificationParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ServerNotification(
|
types.ServerNotification(
|
||||||
LoggingMessageNotification(
|
types.LoggingMessageNotification(
|
||||||
method="notifications/message",
|
method="notifications/message",
|
||||||
params=LoggingMessageNotificationParams(
|
params=types.LoggingMessageNotificationParams(
|
||||||
level=level,
|
level=level,
|
||||||
data=data,
|
data=data,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
@@ -123,43 +96,33 @@ class ServerSession(
|
|||||||
|
|
||||||
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
||||||
"""Send a resource updated notification."""
|
"""Send a resource updated notification."""
|
||||||
from mcp.types import (
|
|
||||||
ResourceUpdatedNotification,
|
|
||||||
ResourceUpdatedNotificationParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ServerNotification(
|
types.ServerNotification(
|
||||||
ResourceUpdatedNotification(
|
types.ResourceUpdatedNotification(
|
||||||
method="notifications/resources/updated",
|
method="notifications/resources/updated",
|
||||||
params=ResourceUpdatedNotificationParams(uri=uri),
|
params=types.ResourceUpdatedNotificationParams(uri=uri),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_message(
|
async def create_message(
|
||||||
self,
|
self,
|
||||||
messages: list[SamplingMessage],
|
messages: list[types.SamplingMessage],
|
||||||
*,
|
*,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
include_context: IncludeContext | None = None,
|
include_context: types.IncludeContext | None = None,
|
||||||
temperature: float | None = None,
|
temperature: float | None = None,
|
||||||
stop_sequences: list[str] | None = None,
|
stop_sequences: list[str] | None = None,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
model_preferences: ModelPreferences | None = None,
|
model_preferences: types.ModelPreferences | None = None,
|
||||||
) -> CreateMessageResult:
|
) -> types.CreateMessageResult:
|
||||||
"""Send a sampling/create_message request."""
|
"""Send a sampling/create_message request."""
|
||||||
from mcp.types import (
|
|
||||||
CreateMessageRequest,
|
|
||||||
CreateMessageRequestParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ServerRequest(
|
types.ServerRequest(
|
||||||
CreateMessageRequest(
|
types.CreateMessageRequest(
|
||||||
method="sampling/createMessage",
|
method="sampling/createMessage",
|
||||||
params=CreateMessageRequestParams(
|
params=types.CreateMessageRequestParams(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
systemPrompt=system_prompt,
|
systemPrompt=system_prompt,
|
||||||
includeContext=include_context,
|
includeContext=include_context,
|
||||||
@@ -171,46 +134,40 @@ class ServerSession(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
CreateMessageResult,
|
types.CreateMessageResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_roots(self) -> ListRootsResult:
|
async def list_roots(self) -> types.ListRootsResult:
|
||||||
"""Send a roots/list request."""
|
"""Send a roots/list request."""
|
||||||
from mcp.types import ListRootsRequest
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ServerRequest(
|
types.ServerRequest(
|
||||||
ListRootsRequest(
|
types.ListRootsRequest(
|
||||||
method="roots/list",
|
method="roots/list",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
ListRootsResult,
|
types.ListRootsResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_ping(self) -> EmptyResult:
|
async def send_ping(self) -> types.EmptyResult:
|
||||||
"""Send a ping request."""
|
"""Send a ping request."""
|
||||||
from mcp.types import PingRequest
|
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
ServerRequest(
|
types.ServerRequest(
|
||||||
PingRequest(
|
types.PingRequest(
|
||||||
method="ping",
|
method="ping",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
EmptyResult,
|
types.EmptyResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_progress_notification(
|
async def send_progress_notification(
|
||||||
self, progress_token: str | int, progress: float, total: float | None = None
|
self, progress_token: str | int, progress: float, total: float | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send a progress notification."""
|
"""Send a progress notification."""
|
||||||
from mcp.types import ProgressNotification, ProgressNotificationParams
|
|
||||||
|
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ServerNotification(
|
types.ServerNotification(
|
||||||
ProgressNotification(
|
types.ProgressNotification(
|
||||||
method="notifications/progress",
|
method="notifications/progress",
|
||||||
params=ProgressNotificationParams(
|
params=types.ProgressNotificationParams(
|
||||||
progressToken=progress_token,
|
progressToken=progress_token,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
total=total,
|
total=total,
|
||||||
@@ -222,8 +179,8 @@ class ServerSession(
|
|||||||
async def send_resource_list_changed(self) -> None:
|
async def send_resource_list_changed(self) -> None:
|
||||||
"""Send a resource list changed notification."""
|
"""Send a resource list changed notification."""
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ServerNotification(
|
types.ServerNotification(
|
||||||
ResourceListChangedNotification(
|
types.ResourceListChangedNotification(
|
||||||
method="notifications/resources/list_changed",
|
method="notifications/resources/list_changed",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -232,8 +189,8 @@ class ServerSession(
|
|||||||
async def send_tool_list_changed(self) -> None:
|
async def send_tool_list_changed(self) -> None:
|
||||||
"""Send a tool list changed notification."""
|
"""Send a tool list changed notification."""
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ServerNotification(
|
types.ServerNotification(
|
||||||
ToolListChangedNotification(
|
types.ToolListChangedNotification(
|
||||||
method="notifications/tools/list_changed",
|
method="notifications/tools/list_changed",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -242,8 +199,8 @@ class ServerSession(
|
|||||||
async def send_prompt_list_changed(self) -> None:
|
async def send_prompt_list_changed(self) -> None:
|
||||||
"""Send a prompt list changed notification."""
|
"""Send a prompt list changed notification."""
|
||||||
await self.send_notification(
|
await self.send_notification(
|
||||||
ServerNotification(
|
types.ServerNotification(
|
||||||
PromptListChangedNotification(
|
types.PromptListChangedNotification(
|
||||||
method="notifications/prompts/list_changed",
|
method="notifications/prompts/list_changed",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from starlette.requests import Request
|
|||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
from mcp.types import JSONRPCMessage
|
import mcp.types as types
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ class SseServerTransport:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_endpoint: str
|
_endpoint: str
|
||||||
_read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]]
|
_read_stream_writers: dict[UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]]
|
||||||
|
|
||||||
def __init__(self, endpoint: str) -> None:
|
def __init__(self, endpoint: str) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -50,11 +50,11 @@ class SseServerTransport:
|
|||||||
raise ValueError("connect_sse can only handle HTTP requests")
|
raise ValueError("connect_sse can only handle HTTP requests")
|
||||||
|
|
||||||
logger.debug("Setting up SSE connection")
|
logger.debug("Setting up SSE connection")
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -125,7 +125,7 @@ class SseServerTransport:
|
|||||||
logger.debug(f"Received JSON: {json}")
|
logger.debug(f"Received JSON: {json}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message = JSONRPCMessage.model_validate(json)
|
message = types.JSONRPCMessage.model_validate(json)
|
||||||
logger.debug(f"Validated client message: {message}")
|
logger.debug(f"Validated client message: {message}")
|
||||||
except ValidationError as err:
|
except ValidationError as err:
|
||||||
logger.error(f"Failed to parse message: {err}")
|
logger.error(f"Failed to parse message: {err}")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import anyio
|
|||||||
import anyio.lowlevel
|
import anyio.lowlevel
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
|
|
||||||
from mcp.types import JSONRPCMessage
|
import mcp.types as types
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -24,11 +24,11 @@ async def stdio_server(
|
|||||||
if not stdout:
|
if not stdout:
|
||||||
stdout = anyio.wrap_file(sys.stdout)
|
stdout = anyio.wrap_file(sys.stdout)
|
||||||
|
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -38,7 +38,7 @@ async def stdio_server(
|
|||||||
async with read_stream_writer:
|
async with read_stream_writer:
|
||||||
async for line in stdin:
|
async for line in stdin:
|
||||||
try:
|
try:
|
||||||
message = JSONRPCMessage.model_validate_json(line)
|
message = types.JSONRPCMessage.model_validate_json(line)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
|
|||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
from mcp.types import JSONRPCMessage
|
import mcp.types as types
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -21,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
|||||||
websocket = WebSocket(scope, receive, send)
|
websocket = WebSocket(scope, receive, send)
|
||||||
await websocket.accept(subprotocol="mcp")
|
await websocket.accept(subprotocol="mcp")
|
||||||
|
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||||
|
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||||
|
|
||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
@@ -35,7 +35,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
|||||||
async with read_stream_writer:
|
async with read_stream_writer:
|
||||||
async for message in websocket.iter_json():
|
async for message in websocket.iter_json():
|
||||||
try:
|
try:
|
||||||
client_message = JSONRPCMessage.model_validate(message)
|
client_message = types.JSONRPCMessage.model_validate(message)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
await read_stream_writer.send(exc)
|
await read_stream_writer.send(exc)
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user