mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Format with ruff
This commit is contained in:
@@ -37,7 +37,6 @@ from .types import (
|
||||
ReadResourceResult,
|
||||
Resource,
|
||||
ResourceUpdatedNotification,
|
||||
Role as SamplingRole,
|
||||
SamplingMessage,
|
||||
ServerCapabilities,
|
||||
ServerNotification,
|
||||
@@ -49,6 +48,9 @@ from .types import (
|
||||
Tool,
|
||||
UnsubscribeRequest,
|
||||
)
|
||||
from .types import (
|
||||
Role as SamplingRole,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CallToolRequest",
|
||||
|
||||
@@ -62,7 +62,8 @@ class ClientSession(
|
||||
|
||||
if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION:
|
||||
raise RuntimeError(
|
||||
f"Unsupported protocol version from the server: {result.protocolVersion}"
|
||||
"Unsupported protocol version from the server: "
|
||||
f"{result.protocolVersion}"
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
|
||||
@@ -19,11 +19,17 @@ def remove_request_params(url: str) -> str:
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5):
|
||||
async def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5,
|
||||
sse_read_timeout: float = 60 * 5,
|
||||
):
|
||||
"""
|
||||
Client transport for SSE.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
@@ -67,7 +73,10 @@ async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: f
|
||||
or url_parsed.scheme
|
||||
!= endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
error_msg = (
|
||||
"Endpoint origin does not match "
|
||||
f"connection origin: {endpoint_url}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -104,11 +113,16 @@ async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: f
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
json=message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(
|
||||
f"Client message sent successfully: {response.status_code}"
|
||||
"Client message sent successfully: "
|
||||
f"{response.status_code}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
|
||||
@@ -28,7 +28,8 @@ class StdioServerParameters(BaseModel):
|
||||
@asynccontextmanager
|
||||
async def stdio_client(server: StdioServerParameters):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout.
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
@@ -55,9 +55,11 @@ class Server:
|
||||
|
||||
def create_initialization_options(self) -> types.InitializationOptions:
|
||||
"""Create initialization options from this server instance."""
|
||||
|
||||
def pkg_version(package: str) -> str:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
return version(package)
|
||||
except Exception:
|
||||
return "unknown"
|
||||
@@ -69,16 +71,17 @@ class Server:
|
||||
)
|
||||
|
||||
def get_capabilities(self) -> ServerCapabilities:
|
||||
"""Convert existing handlers to a ServerCapabilities object."""
|
||||
def get_capability(req_type: type) -> dict[str, Any] | None:
|
||||
return {} if req_type in self.request_handlers else None
|
||||
"""Convert existing handlers to a ServerCapabilities object."""
|
||||
|
||||
return ServerCapabilities(
|
||||
prompts=get_capability(ListPromptsRequest),
|
||||
resources=get_capability(ListResourcesRequest),
|
||||
tools=get_capability(ListPromptsRequest),
|
||||
logging=get_capability(SetLevelRequest)
|
||||
)
|
||||
def get_capability(req_type: type) -> dict[str, Any] | None:
|
||||
return {} if req_type in self.request_handlers else None
|
||||
|
||||
return ServerCapabilities(
|
||||
prompts=get_capability(ListPromptsRequest),
|
||||
resources=get_capability(ListResourcesRequest),
|
||||
tools=get_capability(ListPromptsRequest),
|
||||
logging=get_capability(SetLevelRequest),
|
||||
)
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext:
|
||||
@@ -87,7 +90,7 @@ class Server:
|
||||
|
||||
def list_prompts(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
|
||||
logger.debug(f"Registering handler for PromptListRequest")
|
||||
logger.debug("Registering handler for PromptListRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
prompts = await func()
|
||||
@@ -103,17 +106,19 @@ class Server:
|
||||
GetPromptRequest,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
Role as Role,
|
||||
SamplingMessage,
|
||||
TextContent,
|
||||
)
|
||||
from mcp_python.types import (
|
||||
Role as Role,
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str, dict[str, str] | None], Awaitable[types.PromptResponse]
|
||||
],
|
||||
):
|
||||
logger.debug(f"Registering handler for GetPromptRequest")
|
||||
logger.debug("Registering handler for GetPromptRequest")
|
||||
|
||||
async def handler(req: GetPromptRequest):
|
||||
prompt_get = await func(req.params.name, req.params.arguments)
|
||||
@@ -149,7 +154,7 @@ class Server:
|
||||
|
||||
def list_resources(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
|
||||
logger.debug(f"Registering handler for ListResourcesRequest")
|
||||
logger.debug("Registering handler for ListResourcesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
@@ -169,7 +174,7 @@ class Server:
|
||||
)
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
|
||||
logger.debug(f"Registering handler for ReadResourceRequest")
|
||||
logger.debug("Registering handler for ReadResourceRequest")
|
||||
|
||||
async def handler(req: ReadResourceRequest):
|
||||
result = await func(req.params.uri)
|
||||
@@ -204,7 +209,7 @@ class Server:
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for SetLevelRequest")
|
||||
logger.debug("Registering handler for SetLevelRequest")
|
||||
|
||||
async def handler(req: SetLevelRequest):
|
||||
await func(req.params.level)
|
||||
@@ -219,7 +224,7 @@ class Server:
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for SubscribeRequest")
|
||||
logger.debug("Registering handler for SubscribeRequest")
|
||||
|
||||
async def handler(req: SubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
@@ -234,7 +239,7 @@ class Server:
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for UnsubscribeRequest")
|
||||
logger.debug("Registering handler for UnsubscribeRequest")
|
||||
|
||||
async def handler(req: UnsubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
@@ -249,7 +254,7 @@ class Server:
|
||||
from mcp_python.types import CallToolResult
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
logger.debug(f"Registering handler for CallToolRequest")
|
||||
logger.debug("Registering handler for CallToolRequest")
|
||||
|
||||
async def handler(req: CallToolRequest):
|
||||
result = await func(req.params.name, **(req.params.arguments or {}))
|
||||
@@ -264,7 +269,7 @@ class Server:
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug(f"Registering handler for ProgressNotification")
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: ProgressNotification):
|
||||
await func(
|
||||
@@ -286,7 +291,7 @@ class Server:
|
||||
Awaitable[Completion | None],
|
||||
],
|
||||
):
|
||||
logger.debug(f"Registering handler for CompleteRequest")
|
||||
logger.debug("Registering handler for CompleteRequest")
|
||||
|
||||
async def handler(req: CompleteRequest):
|
||||
completion = await func(req.params.ref, req.params.argument)
|
||||
@@ -307,10 +312,12 @@ class Server:
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
initialization_options: types.InitializationOptions
|
||||
initialization_options: types.InitializationOptions,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(read_stream, write_stream, initialization_options) as session:
|
||||
async with ServerSession(
|
||||
read_stream, write_stream, initialization_options
|
||||
) as session:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
@@ -359,14 +366,16 @@ class Server:
|
||||
|
||||
handler = self.notification_handlers[type(notify)]
|
||||
logger.debug(
|
||||
f"Dispatching notification of type {type(notify).__name__}"
|
||||
f"Dispatching notification of type "
|
||||
f"{type(notify).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(notify)
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"Uncaught exception in notification handler: {err}"
|
||||
f"Uncaught exception in notification handler: "
|
||||
f"{err}"
|
||||
)
|
||||
|
||||
for warning in w:
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import sys
|
||||
import importlib.metadata
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp_python.server.session import ServerSession
|
||||
from mcp_python.server.types import InitializationOptions
|
||||
from mcp_python.server.stdio import stdio_server
|
||||
from mcp_python.server.types import InitializationOptions
|
||||
from mcp_python.types import ServerCapabilities
|
||||
|
||||
if not sys.warnoptions:
|
||||
@@ -30,7 +31,18 @@ async def receive_loop(session: ServerSession):
|
||||
async def main():
|
||||
version = importlib.metadata.version("mcp_python")
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
|
||||
async with (
|
||||
ServerSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="mcp_python",
|
||||
server_version=version,
|
||||
capabilities=ServerCapabilities(),
|
||||
),
|
||||
) as session,
|
||||
write_stream,
|
||||
):
|
||||
await receive_loop(session)
|
||||
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.server.types import InitializationOptions
|
||||
from mcp_python.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
)
|
||||
from mcp_python.server.types import InitializationOptions
|
||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
||||
from mcp_python.types import (
|
||||
ClientNotification,
|
||||
@@ -25,7 +25,6 @@ from mcp_python.types import (
|
||||
JSONRPCMessage,
|
||||
LoggingLevel,
|
||||
SamplingMessage,
|
||||
ServerCapabilities,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
@@ -53,7 +52,7 @@ class ServerSession(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
init_options: InitializationOptions
|
||||
init_options: InitializationOptions,
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
||||
self._initialization_state = InitializationState.NotInitialized
|
||||
@@ -72,7 +71,7 @@ class ServerSession(
|
||||
capabilities=self._init_options.capabilities,
|
||||
serverInfo=Implementation(
|
||||
name=self._init_options.server_name,
|
||||
version=self._init_options.server_version
|
||||
version=self._init_options.server_version,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -19,10 +19,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class SseServerTransport:
|
||||
"""
|
||||
SSE server transport for MCP. This class provides _two_ ASGI applications, suitable to be used with a framework like Starlette and a server like Hypercorn:
|
||||
SSE server transport for MCP. This class provides _two_ ASGI applications,
|
||||
suitable to be used with a framework like Starlette and a server like Hypercorn:
|
||||
|
||||
1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client.
|
||||
2. handle_post_message() is an ASGI application which receives incoming POST requests, which should contain client messages that link to a previously-established SSE session.
|
||||
1. connect_sse() is an ASGI application which receives incoming GET requests,
|
||||
and sets up a new SSE stream to send server messages to the client.
|
||||
2. handle_post_message() is an ASGI application which receives incoming POST
|
||||
requests, which should contain client messages that link to a
|
||||
previously-established SSE session.
|
||||
"""
|
||||
|
||||
_endpoint: str
|
||||
@@ -30,7 +34,8 @@ class SseServerTransport:
|
||||
|
||||
def __init__(self, endpoint: str) -> None:
|
||||
"""
|
||||
Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given.
|
||||
Creates a new SSE server transport, which will direct the client to POST
|
||||
messages to the relative or absolute URL given.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@@ -74,7 +79,9 @@ class SseServerTransport:
|
||||
await sse_stream_writer.send(
|
||||
{
|
||||
"event": "message",
|
||||
"data": message.model_dump_json(by_alias=True, exclude_none=True),
|
||||
"data": message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -7,14 +7,18 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_server(
|
||||
stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None
|
||||
stdin: anyio.AsyncFile[str] | None = None,
|
||||
stdout: anyio.AsyncFile[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout.
|
||||
Server transport for stdio: this communicates with an MCP client by reading
|
||||
from the current process' stdin and writing to stdout.
|
||||
"""
|
||||
# Purposely not using context managers for these, as we don't want to close standard process handles.
|
||||
# Purposely not using context managers for these, as we don't want to close
|
||||
# standard process handles.
|
||||
if not stdin:
|
||||
stdin = anyio.wrap_file(sys.stdin)
|
||||
if not stdout:
|
||||
|
||||
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp_python.types import Role, ServerCapabilities
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ logger = logging.getLogger(__name__)
|
||||
@asynccontextmanager
|
||||
async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
||||
"""
|
||||
WebSocket server transport for MCP. This is an ASGI application, suitable to be used with a framework like Starlette and a server like Hypercorn.
|
||||
WebSocket server transport for MCP. This is an ASGI application, suitable to be
|
||||
used with a framework like Starlette and a server like Hypercorn.
|
||||
"""
|
||||
|
||||
websocket = WebSocket(scope, receive, send)
|
||||
@@ -47,7 +48,9 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
obj = message.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
obj = message.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
await websocket.send_json(obj)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
@@ -69,9 +69,11 @@ class BaseSession(
|
||||
],
|
||||
):
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features like request/response linking, notifications, and progress.
|
||||
Implements an MCP "session" on top of read/write streams, including features
|
||||
like request/response linking, notifications, and progress.
|
||||
|
||||
This class is an async context manager that automatically starts processing messages when entered.
|
||||
This class is an async context manager that automatically starts processing
|
||||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[
|
||||
@@ -108,7 +110,9 @@ class BaseSession(
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Using BaseSession as a context manager should not block on exit (this would be very surprising behavior), so make sure to cancel the tasks in the task group.
|
||||
# Using BaseSession as a context manager should not block on exit (this
|
||||
# would be very surprising behavior), so make sure to cancel the tasks
|
||||
# in the task group.
|
||||
self._task_group.cancel_scope.cancel()
|
||||
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
@@ -118,9 +122,11 @@ class BaseSession(
|
||||
result_type: type[ReceiveResultT],
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the response contains an error.
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
response contains an error.
|
||||
|
||||
Do not use this method to emit notifications! Use send_notification() instead.
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
|
||||
request_id = self._request_id
|
||||
@@ -132,7 +138,9 @@ class BaseSession(
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0", id=request_id, **request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
# TODO: Support progress callbacks
|
||||
@@ -147,10 +155,12 @@ class BaseSession(
|
||||
|
||||
async def send_notification(self, notification: SendNotificationT) -> None:
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect a response.
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
|
||||
@@ -165,7 +175,9 @@ class BaseSession(
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
result=response.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
)
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
|
||||
|
||||
@@ -180,7 +192,9 @@ class BaseSession(
|
||||
await self._incoming_message_stream_writer.send(message)
|
||||
elif isinstance(message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
responder = RequestResponder(
|
||||
request_id=message.root.id,
|
||||
@@ -196,7 +210,9 @@ class BaseSession(
|
||||
await self._incoming_message_stream_writer.send(responder)
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
await self._received_notification(notification)
|
||||
@@ -208,7 +224,8 @@ class BaseSession(
|
||||
else:
|
||||
await self._incoming_message_stream_writer.send(
|
||||
RuntimeError(
|
||||
f"Received response with an unknown request ID: {message}"
|
||||
"Received response with an unknown "
|
||||
f"request ID: {message}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -216,21 +233,25 @@ class BaseSession(
|
||||
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
|
||||
) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to listen on the message stream.
|
||||
Can be overridden by subclasses to handle a request without needing to
|
||||
listen on the message stream.
|
||||
|
||||
If the request is responded to within this method, it will not be forwarded on to the message stream.
|
||||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
|
||||
async def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing to listen on the message stream.
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being processed.
|
||||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
|
||||
@property
|
||||
|
||||
@@ -6,14 +6,19 @@ from pydantic.networks import AnyUrl
|
||||
"""
|
||||
Model Context Protocol bindings for Python
|
||||
|
||||
These bindings were generated from https://github.com/anthropic-experimental/mcp-spec, using Claude, with a prompt something like the following:
|
||||
These bindings were generated from https://github.com/anthropic-experimental/mcp-spec,
|
||||
using Claude, with a prompt something like the following:
|
||||
|
||||
Generate idiomatic Python bindings for this schema for MCP, or the "Model Context Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version for reference.
|
||||
Generate idiomatic Python bindings for this schema for MCP, or the "Model Context
|
||||
Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version
|
||||
for reference.
|
||||
|
||||
* For the bindings, let's use Pydantic V2 models.
|
||||
* Each model should allow extra fields everywhere, by specifying `model_config = ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class.
|
||||
* Each model should allow extra fields everywhere, by specifying `model_config =
|
||||
ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class.
|
||||
* Union types should be represented with a Pydantic `RootModel`.
|
||||
* Define additional model classes instead of using dictionaries. Do this even if they're not separate types in the schema.
|
||||
* Define additional model classes instead of using dictionaries. Do this even if they're
|
||||
not separate types in the schema.
|
||||
"""
|
||||
|
||||
|
||||
@@ -24,7 +29,10 @@ class RequestParams(BaseModel):
|
||||
class Meta(BaseModel):
|
||||
progressToken: ProgressToken | None = None
|
||||
"""
|
||||
If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications.
|
||||
If specified, the caller is requesting out-of-band progress notifications for
|
||||
this request (as represented by notifications/progress). The value of this
|
||||
parameter is an opaque token that will be attached to any subsequent
|
||||
notifications. The receiver is not obligated to provide these notifications.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
@@ -38,9 +46,11 @@ class NotificationParams(BaseModel):
|
||||
|
||||
_meta: Meta | None = None
|
||||
"""
|
||||
This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications.
|
||||
This parameter name is reserved by MCP to allow clients and servers to attach
|
||||
additional metadata to their notifications.
|
||||
"""
|
||||
|
||||
|
||||
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams)
|
||||
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams)
|
||||
MethodT = TypeVar("MethodT", bound=str)
|
||||
@@ -68,7 +78,8 @@ class Result(BaseModel):
|
||||
|
||||
_meta: dict[str, Any] | None = None
|
||||
"""
|
||||
This result property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses.
|
||||
This result property is reserved by the protocol to allow clients and servers to
|
||||
attach additional metadata to their responses.
|
||||
"""
|
||||
|
||||
|
||||
@@ -112,9 +123,15 @@ class ErrorData(BaseModel):
|
||||
code: int
|
||||
"""The error type that occurred."""
|
||||
message: str
|
||||
"""A short description of the error. The message SHOULD be limited to a concise single sentence."""
|
||||
"""
|
||||
A short description of the error. The message SHOULD be limited to a concise single
|
||||
sentence.
|
||||
"""
|
||||
data: Any | None = None
|
||||
"""Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.)."""
|
||||
"""
|
||||
Additional information about the error. The value of this member is defined by the
|
||||
sender (e.g. detailed error information, nested errors etc.).
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -182,14 +199,17 @@ class InitializeRequestParams(RequestParams):
|
||||
|
||||
|
||||
class InitializeRequest(Request):
|
||||
"""This request is sent from the client to the server when it first connects, asking it to begin initialization."""
|
||||
"""
|
||||
This request is sent from the client to the server when it first connects, asking it
|
||||
to begin initialization.
|
||||
"""
|
||||
|
||||
method: Literal["initialize"]
|
||||
params: InitializeRequestParams
|
||||
|
||||
|
||||
class InitializeResult(Result):
|
||||
"""After receiving an initialize request from the client, the server sends this response."""
|
||||
"""After receiving an initialize request from the client, the server sends this."""
|
||||
|
||||
protocolVersion: Literal[1]
|
||||
"""The version of the Model Context Protocol that the server wants to use."""
|
||||
@@ -198,14 +218,20 @@ class InitializeResult(Result):
|
||||
|
||||
|
||||
class InitializedNotification(Notification):
|
||||
"""This notification is sent from the client to the server after initialization has finished."""
|
||||
"""
|
||||
This notification is sent from the client to the server after initialization has
|
||||
finished.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/initialized"]
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
class PingRequest(Request):
|
||||
"""A ping, issued by either the server or the client, to check that the other party is still alive."""
|
||||
"""
|
||||
A ping, issued by either the server or the client, to check that the other party is
|
||||
still alive.
|
||||
"""
|
||||
|
||||
method: Literal["ping"]
|
||||
params: RequestParams | None = None
|
||||
@@ -215,16 +241,25 @@ class ProgressNotificationParams(NotificationParams):
|
||||
"""Parameters for progress notifications."""
|
||||
|
||||
progressToken: ProgressToken
|
||||
"""The progress token which was given in the initial request, used to associate this notification with the request that is proceeding."""
|
||||
"""
|
||||
The progress token which was given in the initial request, used to associate this
|
||||
notification with the request that is proceeding.
|
||||
"""
|
||||
progress: float
|
||||
"""The progress thus far. This should increase every time progress is made, even if the total is unknown."""
|
||||
"""
|
||||
The progress thus far. This should increase every time progress is made, even if the
|
||||
total is unknown.
|
||||
"""
|
||||
total: float | None = None
|
||||
"""Total number of items to process (or total progress required), if known."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ProgressNotification(Notification):
|
||||
"""An out-of-band notification used to inform the receiver of a progress update for a long-running request."""
|
||||
"""
|
||||
An out-of-band notification used to inform the receiver of a progress update for a
|
||||
long-running request.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/progress"]
|
||||
params: ProgressNotificationParams
|
||||
@@ -251,13 +286,19 @@ class ResourceTemplate(BaseModel):
|
||||
"""A template description for resources available on the server."""
|
||||
|
||||
uriTemplate: str
|
||||
"""A URI template (according to RFC 6570) that can be used to construct resource URIs."""
|
||||
"""
|
||||
A URI template (according to RFC 6570) that can be used to construct resource
|
||||
URIs.
|
||||
"""
|
||||
name: str | None = None
|
||||
"""A human-readable name for the type of resource this template refers to."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of what this template is for."""
|
||||
mimeType: str | None = None
|
||||
"""The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type."""
|
||||
"""
|
||||
The MIME type for all resources that match this template. This should only be
|
||||
included if all resources matching this template have the same type.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -272,7 +313,10 @@ class ReadResourceRequestParams(RequestParams):
|
||||
"""Parameters for reading a resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it."""
|
||||
"""
|
||||
The URI of the resource to read. The URI can use any protocol; it is up to the
|
||||
server how to interpret it.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -297,7 +341,10 @@ class TextResourceContents(ResourceContents):
|
||||
"""Text contents of a resource."""
|
||||
|
||||
text: str
|
||||
"""The text of the item. This must only be set if the item can actually be represented as text (not binary data)."""
|
||||
"""
|
||||
The text of the item. This must only be set if the item can actually be represented
|
||||
as text (not binary data).
|
||||
"""
|
||||
|
||||
|
||||
class BlobResourceContents(ResourceContents):
|
||||
@@ -314,7 +361,10 @@ class ReadResourceResult(Result):
|
||||
|
||||
|
||||
class ResourceListChangedNotification(Notification):
|
||||
"""An optional notification from the server to the client, informing it that the list of resources it can read from has changed."""
|
||||
"""
|
||||
An optional notification from the server to the client, informing it that the list
|
||||
of resources it can read from has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/resources/list_changed"]
|
||||
params: NotificationParams | None = None
|
||||
@@ -324,12 +374,18 @@ class SubscribeRequestParams(RequestParams):
|
||||
"""Parameters for subscribing to a resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it."""
|
||||
"""
|
||||
The URI of the resource to subscribe to. The URI can use any protocol; it is up to
|
||||
the server how to interpret it.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class SubscribeRequest(Request):
|
||||
"""Sent from the client to request resources/updated notifications from the server whenever a particular resource changes."""
|
||||
"""
|
||||
Sent from the client to request resources/updated notifications from the server
|
||||
whenever a particular resource changes.
|
||||
"""
|
||||
|
||||
method: Literal["resources/subscribe"]
|
||||
params: SubscribeRequestParams
|
||||
@@ -344,7 +400,10 @@ class UnsubscribeRequestParams(RequestParams):
|
||||
|
||||
|
||||
class UnsubscribeRequest(Request):
|
||||
"""Sent from the client to request cancellation of resources/updated notifications from the server."""
|
||||
"""
|
||||
Sent from the client to request cancellation of resources/updated notifications from
|
||||
the server.
|
||||
"""
|
||||
|
||||
method: Literal["resources/unsubscribe"]
|
||||
params: UnsubscribeRequestParams
|
||||
@@ -354,19 +413,25 @@ class ResourceUpdatedNotificationParams(NotificationParams):
|
||||
"""Parameters for resource update notifications."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to."""
|
||||
"""
|
||||
The URI of the resource that has been updated. This might be a sub-resource of the
|
||||
one that the client actually subscribed to.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ResourceUpdatedNotification(Notification):
|
||||
"""A notification from the server to the client, informing it that a resource has changed and may need to be read again."""
|
||||
"""
|
||||
A notification from the server to the client, informing it that a resource has
|
||||
changed and may need to be read again.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/resources/updated"]
|
||||
params: ResourceUpdatedNotificationParams
|
||||
|
||||
|
||||
class ListPromptsRequest(Request):
|
||||
"""Sent from the client to request a list of prompts and prompt templates the server has."""
|
||||
"""Sent from the client to request a list of prompts and prompt templates."""
|
||||
|
||||
method: Literal["prompts/list"]
|
||||
params: RequestParams | None = None
|
||||
@@ -435,7 +500,10 @@ class ImageContent(BaseModel):
|
||||
data: str
|
||||
"""The base64-encoded image data."""
|
||||
mimeType: str
|
||||
"""The MIME type of the image. Different providers may support different image types."""
|
||||
"""
|
||||
The MIME type of the image. Different providers may support different
|
||||
image types.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -505,7 +573,10 @@ class CallToolResult(Result):
|
||||
|
||||
|
||||
class ToolListChangedNotification(Notification):
|
||||
"""An optional notification from the server to the client, informing it that the list of tools it offers has changed."""
|
||||
"""
|
||||
An optional notification from the server to the client, informing it that the list
|
||||
of tools it offers has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/tools/list_changed"]
|
||||
params: NotificationParams | None = None
|
||||
@@ -537,7 +608,10 @@ class LoggingMessageNotificationParams(NotificationParams):
|
||||
logger: str | None = None
|
||||
"""An optional name of the logger issuing this message."""
|
||||
data: Any
|
||||
"""The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here."""
|
||||
"""
|
||||
The data to be logged, such as a string message or an object. Any JSON serializable
|
||||
type is allowed here.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -558,7 +632,10 @@ class CreateMessageRequestParams(RequestParams):
|
||||
systemPrompt: str | None = None
|
||||
"""An optional system prompt the server wants to use for sampling."""
|
||||
includeContext: IncludeContext | None = None
|
||||
"""A request to include context from one or more MCP servers (including the caller), to be attached to the prompt."""
|
||||
"""
|
||||
A request to include context from one or more MCP servers (including the caller), to
|
||||
be attached to the prompt.
|
||||
"""
|
||||
temperature: float | None = None
|
||||
maxTokens: int
|
||||
"""The maximum number of tokens to sample, as requested by the server."""
|
||||
@@ -638,9 +715,15 @@ class Completion(BaseModel):
|
||||
values: list[str]
|
||||
"""An array of completion values. Must not exceed 100 items."""
|
||||
total: int | None = None
|
||||
"""The total number of completion options available. This can exceed the number of values actually sent in the response."""
|
||||
"""
|
||||
The total number of completion options available. This can exceed the number of
|
||||
values actually sent in the response.
|
||||
"""
|
||||
hasMore: bool | None = None
|
||||
"""Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown."""
|
||||
"""
|
||||
Indicates whether there are additional completion options beyond those provided in
|
||||
the current response, even if the exact total is unknown.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
|
||||
@@ -59,14 +59,18 @@ async def test_client_session_initialize():
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
result=result.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
jsonrpc_notification = await client_to_server_receive.receive()
|
||||
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
||||
initialized_notification = ClientNotification.model_validate(
|
||||
jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
jsonrpc_notification.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
async def listen_session():
|
||||
|
||||
@@ -33,7 +33,13 @@ async def test_server_session_initialize():
|
||||
nonlocal received_initialized
|
||||
|
||||
async with ServerSession(
|
||||
client_to_server_receive, server_to_client_send, InitializationOptions(server_name='mcp_python', server_version='0.1.0', capabilities=ServerCapabilities())
|
||||
client_to_server_receive,
|
||||
server_to_client_send,
|
||||
InitializationOptions(
|
||||
server_name="mcp_python",
|
||||
server_version="0.1.0",
|
||||
capabilities=ServerCapabilities(),
|
||||
),
|
||||
) as server_session:
|
||||
async for message in server_session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
|
||||
Reference in New Issue
Block a user