mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-20 07:14:24 +01:00
rename mcp_python to mcp
This commit is contained in:
114
src/mcp/__init__.py
Normal file
114
src/mcp/__init__.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from .client.session import ClientSession
|
||||
from .client.stdio import StdioServerParameters, stdio_client
|
||||
from .server.session import ServerSession
|
||||
from .server.stdio import stdio_server
|
||||
from .shared.exceptions import McpError
|
||||
from .types import (
|
||||
CallToolRequest,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
CompleteRequest,
|
||||
CreateMessageRequest,
|
||||
CreateMessageResult,
|
||||
ErrorData,
|
||||
GetPromptRequest,
|
||||
GetPromptResult,
|
||||
Implementation,
|
||||
IncludeContext,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ListPromptsRequest,
|
||||
ListPromptsResult,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResult,
|
||||
ListToolsResult,
|
||||
LoggingLevel,
|
||||
LoggingMessageNotification,
|
||||
Notification,
|
||||
PingRequest,
|
||||
ProgressNotification,
|
||||
PromptsCapability,
|
||||
ReadResourceRequest,
|
||||
ReadResourceResult,
|
||||
Resource,
|
||||
ResourcesCapability,
|
||||
ResourceUpdatedNotification,
|
||||
RootsCapability,
|
||||
SamplingMessage,
|
||||
ServerCapabilities,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
SetLevelRequest,
|
||||
StopReason,
|
||||
SubscribeRequest,
|
||||
Tool,
|
||||
ToolsCapability,
|
||||
UnsubscribeRequest,
|
||||
)
|
||||
from .types import (
|
||||
Role as SamplingRole,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CallToolRequest",
|
||||
"ClientCapabilities",
|
||||
"ClientNotification",
|
||||
"ClientRequest",
|
||||
"ClientResult",
|
||||
"ClientSession",
|
||||
"CreateMessageRequest",
|
||||
"CreateMessageResult",
|
||||
"ErrorData",
|
||||
"GetPromptRequest",
|
||||
"GetPromptResult",
|
||||
"Implementation",
|
||||
"IncludeContext",
|
||||
"InitializeRequest",
|
||||
"InitializeResult",
|
||||
"InitializedNotification",
|
||||
"JSONRPCError",
|
||||
"JSONRPCRequest",
|
||||
"ListPromptsRequest",
|
||||
"ListPromptsResult",
|
||||
"ListResourcesRequest",
|
||||
"ListResourcesResult",
|
||||
"ListToolsResult",
|
||||
"LoggingLevel",
|
||||
"LoggingMessageNotification",
|
||||
"McpError",
|
||||
"Notification",
|
||||
"PingRequest",
|
||||
"ProgressNotification",
|
||||
"PromptsCapability",
|
||||
"ReadResourceRequest",
|
||||
"ReadResourceResult",
|
||||
"ResourcesCapability",
|
||||
"ResourceUpdatedNotification",
|
||||
"Resource",
|
||||
"RootsCapability",
|
||||
"SamplingMessage",
|
||||
"SamplingRole",
|
||||
"ServerCapabilities",
|
||||
"ServerNotification",
|
||||
"ServerRequest",
|
||||
"ServerResult",
|
||||
"ServerSession",
|
||||
"SetLevelRequest",
|
||||
"StdioServerParameters",
|
||||
"StopReason",
|
||||
"SubscribeRequest",
|
||||
"Tool",
|
||||
"ToolsCapability",
|
||||
"UnsubscribeRequest",
|
||||
"stdio_client",
|
||||
"stdio_server",
|
||||
"CompleteRequest",
|
||||
"JSONRPCResponse",
|
||||
]
|
||||
0
src/mcp/client/__init__.py
Normal file
0
src/mcp/client/__init__.py
Normal file
76
src/mcp/client/__main__.py
Normal file
76
src/mcp/client/__main__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
import click
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def receive_loop(session: ClientSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
|
||||
logger.info("Received message from server: %s", message)
|
||||
|
||||
|
||||
async def run_session(read_stream, write_stream):
|
||||
async with (
|
||||
ClientSession(read_stream, write_stream) as session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
tg.start_soon(receive_loop, session)
|
||||
|
||||
logger.info("Initializing session")
|
||||
await session.initialize()
|
||||
logger.info("Initialized")
|
||||
|
||||
|
||||
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
|
||||
env_dict = dict(env)
|
||||
|
||||
if urlparse(command_or_url).scheme in ("http", "https"):
|
||||
# Use SSE client for HTTP(S) URLs
|
||||
async with sse_client(command_or_url) as streams:
|
||||
await run_session(*streams)
|
||||
else:
|
||||
# Use stdio client for commands
|
||||
server_parameters = StdioServerParameters(
|
||||
command=command_or_url, args=args, env=env_dict
|
||||
)
|
||||
async with stdio_client(server_parameters) as streams:
|
||||
await run_session(*streams)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("command_or_url")
|
||||
@click.argument("args", nargs=-1)
|
||||
@click.option(
|
||||
"--env",
|
||||
"-e",
|
||||
multiple=True,
|
||||
nargs=2,
|
||||
metavar="KEY VALUE",
|
||||
help="Environment variables to set. Can be used multiple times.",
|
||||
)
|
||||
def cli(*args, **kwargs):
|
||||
anyio.run(partial(main, *args, **kwargs), backend="trio")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
313
src/mcp/client/session.py
Normal file
313
src/mcp/client/session.py
Normal file
@@ -0,0 +1,313 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
from mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
CallToolResult,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
CompleteResult,
|
||||
EmptyResult,
|
||||
GetPromptResult,
|
||||
Implementation,
|
||||
InitializedNotification,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
ListPromptsResult,
|
||||
ListResourcesResult,
|
||||
ListToolsResult,
|
||||
LoggingLevel,
|
||||
PromptReference,
|
||||
ReadResourceResult,
|
||||
ResourceReference,
|
||||
RootsCapability,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
ClientRequest,
|
||||
ClientNotification,
|
||||
ClientResult,
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
|
||||
async def initialize(self) -> InitializeResult:
|
||||
from mcp.types import (
|
||||
InitializeRequest,
|
||||
InitializeRequestParams,
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
ClientRequest(
|
||||
InitializeRequest(
|
||||
method="initialize",
|
||||
params=InitializeRequestParams(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ClientCapabilities(
|
||||
sampling=None,
|
||||
experimental=None,
|
||||
roots=RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True
|
||||
),
|
||||
),
|
||||
clientInfo=Implementation(name="mcp", version="0.1.0"),
|
||||
),
|
||||
)
|
||||
),
|
||||
InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(
|
||||
"Unsupported protocol version from the server: "
|
||||
f"{result.protocolVersion}"
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
InitializedNotification(method="notifications/initialized")
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def send_ping(self) -> EmptyResult:
|
||||
"""Send a ping request."""
|
||||
from mcp.types import PingRequest
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
from mcp.types import (
|
||||
ProgressNotification,
|
||||
ProgressNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
from mcp.types import (
|
||||
SetLevelRequest,
|
||||
SetLevelRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def list_resources(self) -> ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
from mcp.types import (
|
||||
ListResourcesRequest,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
ListResourcesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
from mcp.types import (
|
||||
ReadResourceRequest,
|
||||
ReadResourceRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
ReadResourceResult,
|
||||
)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
from mcp.types import (
|
||||
SubscribeRequest,
|
||||
SubscribeRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
from mcp.types import (
|
||||
UnsubscribeRequest,
|
||||
UnsubscribeRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict | None = None
|
||||
) -> CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
from mcp.types import (
|
||||
CallToolRequest,
|
||||
CallToolRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
CallToolRequest(
|
||||
method="tools/call",
|
||||
params=CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
CallToolResult,
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> ListPromptsResult:
|
||||
"""Send a prompts/list request."""
|
||||
from mcp.types import ListPromptsRequest
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ListPromptsRequest(
|
||||
method="prompts/list",
|
||||
)
|
||||
),
|
||||
ListPromptsResult,
|
||||
)
|
||||
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: dict[str, str] | None = None
|
||||
) -> GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
from mcp.types import GetPromptRequest, GetPromptRequestParams
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
GetPromptRequest(
|
||||
method="prompts/get",
|
||||
params=GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
GetPromptResult,
|
||||
)
|
||||
|
||||
async def complete(
|
||||
self, ref: ResourceReference | PromptReference, argument: dict
|
||||
) -> CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
from mcp.types import (
|
||||
CompleteRequest,
|
||||
CompleteRequestParams,
|
||||
CompletionArgument,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
CompleteRequest(
|
||||
method="completion/complete",
|
||||
params=CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=CompletionArgument(**argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
CompleteResult,
|
||||
)
|
||||
|
||||
async def list_tools(self) -> ListToolsResult:
|
||||
"""Send a tools/list request."""
|
||||
from mcp.types import ListToolsRequest
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ListToolsRequest(
|
||||
method="tools/list",
|
||||
)
|
||||
),
|
||||
ListToolsResult,
|
||||
)
|
||||
|
||||
async def send_roots_list_changed(self) -> None:
|
||||
"""Send a roots/list_changed notification."""
|
||||
from mcp.types import RootsListChangedNotification
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
RootsListChangedNotification(
|
||||
method="notifications/roots/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
144
src/mcp/client/sse.py
Normal file
144
src/mcp/client/sse.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5,
|
||||
sse_read_timeout: float = 60 * 5,
|
||||
):
|
||||
"""
|
||||
Client transport for SSE.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
async with httpx.AsyncClient(headers=headers) as client:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
url,
|
||||
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
|
||||
async def sse_reader(
|
||||
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||
):
|
||||
try:
|
||||
async for sse in event_source.aiter_sse():
|
||||
logger.debug(f"Received SSE event: {sse.event}")
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
logger.info(
|
||||
f"Received endpoint URL: {endpoint_url}"
|
||||
)
|
||||
|
||||
url_parsed = urlparse(url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
if (
|
||||
url_parsed.netloc != endpoint_parsed.netloc
|
||||
or url_parsed.scheme
|
||||
!= endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = (
|
||||
"Endpoint origin does not match "
|
||||
f"connection origin: {endpoint_url}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
task_status.started(endpoint_url)
|
||||
|
||||
case "message":
|
||||
try:
|
||||
message = (
|
||||
JSONRPCMessage.model_validate_json(
|
||||
sse.data
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"Received server message: {message}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error parsing server message: {exc}"
|
||||
)
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in sse_reader: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
|
||||
async def post_writer(endpoint_url: str):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(
|
||||
"Client message sent successfully: "
|
||||
f"{response.status_code}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
finally:
|
||||
await write_stream.aclose()
|
||||
|
||||
endpoint_url = await tg.start(sse_reader)
|
||||
logger.info(
|
||||
f"Starting post writer with endpoint URL: {endpoint_url}"
|
||||
)
|
||||
tg.start_soon(post_writer, endpoint_url)
|
||||
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
128
src/mcp/client/stdio.py
Normal file
128
src/mcp/client/stdio.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
# Environment variables to inherit by default
|
||||
DEFAULT_INHERITED_ENV_VARS = (
|
||||
[
|
||||
"APPDATA",
|
||||
"HOMEDRIVE",
|
||||
"HOMEPATH",
|
||||
"LOCALAPPDATA",
|
||||
"PATH",
|
||||
"PROCESSOR_ARCHITECTURE",
|
||||
"SYSTEMDRIVE",
|
||||
"SYSTEMROOT",
|
||||
"TEMP",
|
||||
"USERNAME",
|
||||
"USERPROFILE",
|
||||
]
|
||||
if sys.platform == "win32"
|
||||
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
|
||||
)
|
||||
|
||||
|
||||
def get_default_environment() -> dict[str, str]:
|
||||
"""
|
||||
Returns a default environment object including only environment variables deemed
|
||||
safe to inherit.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
|
||||
for key in DEFAULT_INHERITED_ENV_VARS:
|
||||
value = os.environ.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if value.startswith("()"):
|
||||
# Skip functions, which are a security risk
|
||||
continue
|
||||
|
||||
env[key] = value
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class StdioServerParameters(BaseModel):
|
||||
command: str
|
||||
"""The executable to run to start the server."""
|
||||
|
||||
args: list[str] = Field(default_factory=list)
|
||||
"""Command line arguments to pass to the executable."""
|
||||
|
||||
env: dict[str, str] | None = None
|
||||
"""
|
||||
The environment to use when spawning the process.
|
||||
|
||||
If not specified, the result of get_default_environment() will be used.
|
||||
"""
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_client(server: StdioServerParameters):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
process = await anyio.open_process(
|
||||
[server.command, *server.args],
|
||||
env=server.env if server.env is not None else get_default_environment(),
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
buffer = ""
|
||||
async for chunk in TextReceiveStream(process.stdout):
|
||||
lines = (buffer + chunk).split("\n")
|
||||
buffer = lines.pop()
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdin_writer():
|
||||
assert process.stdin, "Opened process is missing stdin"
|
||||
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await process.stdin.send((json + "\n").encode())
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with (
|
||||
anyio.create_task_group() as tg,
|
||||
process,
|
||||
):
|
||||
tg.start_soon(stdout_reader)
|
||||
tg.start_soon(stdin_writer)
|
||||
yield read_stream, write_stream
|
||||
0
src/mcp/py.typed
Normal file
0
src/mcp/py.typed
Normal file
513
src/mcp/server/__init__.py
Normal file
513
src/mcp/server/__init__.py
Normal file
@@ -0,0 +1,513 @@
|
||||
import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Sequence
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp.server import types
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.stdio import stdio_server as stdio_server
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import (
|
||||
METHOD_NOT_FOUND,
|
||||
CallToolRequest,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CompleteRequest,
|
||||
EmbeddedResource,
|
||||
EmptyResult,
|
||||
ErrorData,
|
||||
JSONRPCMessage,
|
||||
ListPromptsRequest,
|
||||
ListPromptsResult,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResult,
|
||||
ListToolsRequest,
|
||||
ListToolsResult,
|
||||
LoggingCapability,
|
||||
LoggingLevel,
|
||||
PingRequest,
|
||||
ProgressNotification,
|
||||
Prompt,
|
||||
PromptMessage,
|
||||
PromptReference,
|
||||
PromptsCapability,
|
||||
ReadResourceRequest,
|
||||
ReadResourceResult,
|
||||
Resource,
|
||||
ResourceReference,
|
||||
ResourcesCapability,
|
||||
ServerCapabilities,
|
||||
ServerResult,
|
||||
SetLevelRequest,
|
||||
SubscribeRequest,
|
||||
TextContent,
|
||||
Tool,
|
||||
ToolsCapability,
|
||||
UnsubscribeRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
|
||||
"request_ctx"
|
||||
)
|
||||
|
||||
|
||||
class NotificationOptions:
|
||||
def __init__(
|
||||
self,
|
||||
prompts_changed: bool = False,
|
||||
resources_changed: bool = False,
|
||||
tools_changed: bool = False,
|
||||
):
|
||||
self.prompts_changed = prompts_changed
|
||||
self.resources_changed = resources_changed
|
||||
self.tools_changed = tools_changed
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
|
||||
PingRequest: _ping_handler,
|
||||
}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
self.notification_options = NotificationOptions()
|
||||
logger.debug(f"Initializing server '{name}'")
|
||||
|
||||
def create_initialization_options(
|
||||
self,
|
||||
notification_options: NotificationOptions | None = None,
|
||||
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
|
||||
) -> types.InitializationOptions:
|
||||
"""Create initialization options from this server instance."""
|
||||
|
||||
def pkg_version(package: str) -> str:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
v = version(package)
|
||||
if v is not None:
|
||||
return v
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "unknown"
|
||||
|
||||
return types.InitializationOptions(
|
||||
server_name=self.name,
|
||||
server_version=pkg_version("mcp"),
|
||||
capabilities=self.get_capabilities(
|
||||
notification_options or NotificationOptions(),
|
||||
experimental_capabilities or {},
|
||||
),
|
||||
)
|
||||
|
||||
def get_capabilities(
|
||||
self,
|
||||
notification_options: NotificationOptions,
|
||||
experimental_capabilities: dict[str, dict[str, Any]],
|
||||
) -> ServerCapabilities:
|
||||
"""Convert existing handlers to a ServerCapabilities object."""
|
||||
prompts_capability = None
|
||||
resources_capability = None
|
||||
tools_capability = None
|
||||
logging_capability = None
|
||||
|
||||
# Set prompt capabilities if handler exists
|
||||
if ListPromptsRequest in self.request_handlers:
|
||||
prompts_capability = PromptsCapability(
|
||||
listChanged=notification_options.prompts_changed
|
||||
)
|
||||
|
||||
# Set resource capabilities if handler exists
|
||||
if ListResourcesRequest in self.request_handlers:
|
||||
resources_capability = ResourcesCapability(
|
||||
subscribe=False, listChanged=notification_options.resources_changed
|
||||
)
|
||||
|
||||
# Set tool capabilities if handler exists
|
||||
if ListToolsRequest in self.request_handlers:
|
||||
tools_capability = ToolsCapability(
|
||||
listChanged=notification_options.tools_changed
|
||||
)
|
||||
|
||||
# Set logging capabilities if handler exists
|
||||
if SetLevelRequest in self.request_handlers:
|
||||
logging_capability = LoggingCapability()
|
||||
|
||||
return ServerCapabilities(
|
||||
prompts=prompts_capability,
|
||||
resources=resources_capability,
|
||||
tools=tools_capability,
|
||||
logging=logging_capability,
|
||||
experimental=experimental_capabilities,
|
||||
)
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
def list_prompts(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
|
||||
logger.debug("Registering handler for PromptListRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
prompts = await func()
|
||||
return ServerResult(ListPromptsResult(prompts=prompts))
|
||||
|
||||
self.request_handlers[ListPromptsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_prompt(self):
|
||||
from mcp.types import (
|
||||
GetPromptRequest,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
)
|
||||
from mcp.types import (
|
||||
Role as Role,
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str, dict[str, str] | None], Awaitable[types.PromptResponse]
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for GetPromptRequest")
|
||||
|
||||
async def handler(req: GetPromptRequest):
|
||||
prompt_get = await func(req.params.name, req.params.arguments)
|
||||
messages: list[PromptMessage] = []
|
||||
for message in prompt_get.messages:
|
||||
match message.content:
|
||||
case str() as text_content:
|
||||
content = TextContent(type="text", text=text_content)
|
||||
case types.ImageContent() as img_content:
|
||||
content = ImageContent(
|
||||
type="image",
|
||||
data=img_content.data,
|
||||
mimeType=img_content.mime_type,
|
||||
)
|
||||
case types.EmbeddedResource() as resource:
|
||||
content = EmbeddedResource(
|
||||
type="resource", resource=resource.resource
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {type(message.content)}"
|
||||
)
|
||||
|
||||
prompt_message = PromptMessage(role=message.role, content=content)
|
||||
messages.append(prompt_message)
|
||||
|
||||
return ServerResult(
|
||||
GetPromptResult(description=prompt_get.desc, messages=messages)
|
||||
)
|
||||
|
||||
self.request_handlers[GetPromptRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resources(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
|
||||
logger.debug("Registering handler for ListResourcesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
return ServerResult(ListResourcesResult(resources=resources))
|
||||
|
||||
self.request_handlers[ListResourcesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def read_resource(self):
|
||||
from mcp.types import (
|
||||
BlobResourceContents,
|
||||
TextResourceContents,
|
||||
)
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
|
||||
logger.debug("Registering handler for ReadResourceRequest")
|
||||
|
||||
async def handler(req: ReadResourceRequest):
|
||||
result = await func(req.params.uri)
|
||||
match result:
|
||||
case str(s):
|
||||
content = TextResourceContents(
|
||||
uri=req.params.uri,
|
||||
text=s,
|
||||
mimeType="text/plain",
|
||||
)
|
||||
case bytes(b):
|
||||
import base64
|
||||
|
||||
content = BlobResourceContents(
|
||||
uri=req.params.uri,
|
||||
blob=base64.urlsafe_b64encode(b).decode(),
|
||||
mimeType="application/octet-stream",
|
||||
)
|
||||
|
||||
return ServerResult(
|
||||
ReadResourceResult(
|
||||
contents=[content],
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[ReadResourceRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_logging_level(self):
|
||||
from mcp.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SetLevelRequest")
|
||||
|
||||
async def handler(req: SetLevelRequest):
|
||||
await func(req.params.level)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[SetLevelRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def subscribe_resource(self):
|
||||
from mcp.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SubscribeRequest")
|
||||
|
||||
async def handler(req: SubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[SubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def unsubscribe_resource(self):
|
||||
from mcp.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for UnsubscribeRequest")
|
||||
|
||||
async def handler(req: UnsubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[UnsubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_tools(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[Tool]]]):
|
||||
logger.debug("Registering handler for ListToolsRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
tools = await func()
|
||||
return ServerResult(ListToolsResult(tools=tools))
|
||||
|
||||
self.request_handlers[ListToolsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def call_tool(self):
|
||||
from mcp.types import (
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[Sequence[str | types.ImageContent | types.EmbeddedResource]],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CallToolRequest")
|
||||
|
||||
async def handler(req: CallToolRequest):
|
||||
try:
|
||||
results = await func(req.params.name, (req.params.arguments or {}))
|
||||
content = []
|
||||
for result in results:
|
||||
match result:
|
||||
case str() as text:
|
||||
content.append(TextContent(type="text", text=text))
|
||||
case types.ImageContent() as img:
|
||||
content.append(
|
||||
ImageContent(
|
||||
type="image",
|
||||
data=img.data,
|
||||
mimeType=img.mime_type,
|
||||
)
|
||||
)
|
||||
case types.EmbeddedResource() as resource:
|
||||
content.append(
|
||||
EmbeddedResource(
|
||||
type="resource", resource=resource.resource
|
||||
)
|
||||
)
|
||||
|
||||
return ServerResult(CallToolResult(content=content, isError=False))
|
||||
except Exception as e:
|
||||
return ServerResult(
|
||||
CallToolResult(
|
||||
content=[TextContent(type="text", text=str(e))],
|
||||
isError=True,
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[CallToolRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken, req.params.progress, req.params.total
|
||||
)
|
||||
|
||||
self.notification_handlers[ProgressNotification] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def completion(self):
|
||||
"""Provides completions for prompts and resource templates"""
|
||||
from mcp.types import CompleteResult, Completion, CompletionArgument
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[PromptReference | ResourceReference, CompletionArgument],
|
||||
Awaitable[Completion | None],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CompleteRequest")
|
||||
|
||||
async def handler(req: CompleteRequest):
|
||||
completion = await func(req.params.ref, req.params.argument)
|
||||
return ServerResult(
|
||||
CompleteResult(
|
||||
completion=completion
|
||||
if completion is not None
|
||||
else Completion(values=[], total=None, hasMore=None),
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[CompleteRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
initialization_options: types.InitializationOptions,
|
||||
# When True, exceptions are returned as messages to the client.
|
||||
# When False, exceptions are raised, which will cause the server to shut down
|
||||
# but also make tracing exceptions much easier during testing and when using
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(
|
||||
read_stream, write_stream, initialization_options
|
||||
) as session:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
match message:
|
||||
case RequestResponder(request=ClientRequest(root=req)):
|
||||
logger.info(
|
||||
f"Processing request of type {type(req).__name__}"
|
||||
)
|
||||
if type(req) in self.request_handlers:
|
||||
handler = self.request_handlers[type(req)]
|
||||
logger.debug(
|
||||
f"Dispatching request of type {type(req).__name__}"
|
||||
)
|
||||
|
||||
token = None
|
||||
try:
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
token = request_ctx.set(
|
||||
RequestContext(
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
except Exception as err:
|
||||
if raise_exceptions:
|
||||
raise err
|
||||
response = ErrorData(
|
||||
code=0, message=str(err), data=None
|
||||
)
|
||||
finally:
|
||||
# Reset the global state after we are done
|
||||
if token is not None:
|
||||
request_ctx.reset(token)
|
||||
|
||||
await message.respond(response)
|
||||
else:
|
||||
await message.respond(
|
||||
ErrorData(
|
||||
code=METHOD_NOT_FOUND,
|
||||
message="Method not found",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Response sent")
|
||||
case ClientNotification(root=notify):
|
||||
if type(notify) in self.notification_handlers:
|
||||
assert type(notify) in self.notification_handlers
|
||||
|
||||
handler = self.notification_handlers[type(notify)]
|
||||
logger.debug(
|
||||
f"Dispatching notification of type "
|
||||
f"{type(notify).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(notify)
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"Uncaught exception in notification handler: "
|
||||
f"{err}"
|
||||
)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
)
|
||||
|
||||
|
||||
async def _ping_handler(request: PingRequest) -> ServerResult:
|
||||
return ServerResult(EmptyResult())
|
||||
50
src/mcp/server/__main__.py
Normal file
50
src/mcp/server/__main__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.server.types import InitializationOptions
|
||||
from mcp.types import ServerCapabilities
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("server")
|
||||
|
||||
|
||||
async def receive_loop(session: ServerSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
|
||||
logger.info("Received message from client: %s", message)
|
||||
|
||||
|
||||
async def main():
|
||||
version = importlib.metadata.version("mcp")
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
async with (
|
||||
ServerSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="mcp",
|
||||
server_version=version,
|
||||
capabilities=ServerCapabilities(),
|
||||
),
|
||||
) as session,
|
||||
write_stream,
|
||||
):
|
||||
await receive_loop(session)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
anyio.run(main, backend="trio")
|
||||
250
src/mcp/server/session.py
Normal file
250
src/mcp/server/session.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp.server.types import InitializationOptions
|
||||
from mcp.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
)
|
||||
from mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CreateMessageResult,
|
||||
EmptyResult,
|
||||
Implementation,
|
||||
IncludeContext,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
ListRootsResult,
|
||||
LoggingLevel,
|
||||
ModelPreferences,
|
||||
PromptListChangedNotification,
|
||||
ResourceListChangedNotification,
|
||||
SamplingMessage,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
ToolListChangedNotification,
|
||||
)
|
||||
|
||||
|
||||
class InitializationState(Enum):
|
||||
NotInitialized = 1
|
||||
Initializing = 2
|
||||
Initialized = 3
|
||||
|
||||
|
||||
class ServerSession(
|
||||
BaseSession[
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
ServerResult,
|
||||
ClientRequest,
|
||||
ClientNotification,
|
||||
]
|
||||
):
|
||||
_initialized: InitializationState = InitializationState.NotInitialized
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
init_options: InitializationOptions,
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
||||
self._initialization_state = InitializationState.NotInitialized
|
||||
self._init_options = init_options
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[ClientRequest, ServerResult]
|
||||
):
|
||||
match responder.request.root:
|
||||
case InitializeRequest():
|
||||
self._initialization_state = InitializationState.Initializing
|
||||
await responder.respond(
|
||||
ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self._init_options.capabilities,
|
||||
serverInfo=Implementation(
|
||||
name=self._init_options.server_name,
|
||||
version=self._init_options.server_version,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received request before initialization was complete"
|
||||
)
|
||||
|
||||
async def _received_notification(self, notification: ClientNotification) -> None:
|
||||
# Need this to avoid ASYNC910
|
||||
await anyio.lowlevel.checkpoint()
|
||||
match notification.root:
|
||||
case InitializedNotification():
|
||||
self._initialization_state = InitializationState.Initialized
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received notification before initialization was complete"
|
||||
)
|
||||
|
||||
async def send_log_message(
|
||||
self, level: LoggingLevel, data: Any, logger: str | None = None
|
||||
) -> None:
|
||||
"""Send a log message notification."""
|
||||
from mcp.types import (
|
||||
LoggingMessageNotification,
|
||||
LoggingMessageNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
LoggingMessageNotification(
|
||||
method="notifications/message",
|
||||
params=LoggingMessageNotificationParams(
|
||||
level=level,
|
||||
data=data,
|
||||
logger=logger,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
||||
"""Send a resource updated notification."""
|
||||
from mcp.types import (
|
||||
ResourceUpdatedNotification,
|
||||
ResourceUpdatedNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ResourceUpdatedNotification(
|
||||
method="notifications/resources/updated",
|
||||
params=ResourceUpdatedNotificationParams(uri=uri),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
messages: list[SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
system_prompt: str | None = None,
|
||||
include_context: IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: ModelPreferences | None = None,
|
||||
) -> CreateMessageResult:
|
||||
"""Send a sampling/create_message request."""
|
||||
from mcp.types import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
CreateMessageRequest(
|
||||
method="sampling/createMessage",
|
||||
params=CreateMessageRequestParams(
|
||||
messages=messages,
|
||||
systemPrompt=system_prompt,
|
||||
includeContext=include_context,
|
||||
temperature=temperature,
|
||||
maxTokens=max_tokens,
|
||||
stopSequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
modelPreferences=model_preferences,
|
||||
),
|
||||
)
|
||||
),
|
||||
CreateMessageResult,
|
||||
)
|
||||
|
||||
async def list_roots(self) -> ListRootsResult:
|
||||
"""Send a roots/list request."""
|
||||
from mcp.types import ListRootsRequest
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
ListRootsRequest(
|
||||
method="roots/list",
|
||||
)
|
||||
),
|
||||
ListRootsResult,
|
||||
)
|
||||
|
||||
async def send_ping(self) -> EmptyResult:
|
||||
"""Send a ping request."""
|
||||
from mcp.types import PingRequest
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
from mcp.types import ProgressNotification, ProgressNotificationParams
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_resource_list_changed(self) -> None:
|
||||
"""Send a resource list changed notification."""
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ResourceListChangedNotification(
|
||||
method="notifications/resources/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_tool_list_changed(self) -> None:
|
||||
"""Send a tool list changed notification."""
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ToolListChangedNotification(
|
||||
method="notifications/tools/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_prompt_list_changed(self) -> None:
|
||||
"""Send a prompt list changed notification."""
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
PromptListChangedNotification(
|
||||
method="notifications/prompts/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
140
src/mcp/server/sse.py
Normal file
140
src/mcp/server/sse.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SseServerTransport:
|
||||
"""
|
||||
SSE server transport for MCP. This class provides _two_ ASGI applications,
|
||||
suitable to be used with a framework like Starlette and a server like Hypercorn:
|
||||
|
||||
1. connect_sse() is an ASGI application which receives incoming GET requests,
|
||||
and sets up a new SSE stream to send server messages to the client.
|
||||
2. handle_post_message() is an ASGI application which receives incoming POST
|
||||
requests, which should contain client messages that link to a
|
||||
previously-established SSE session.
|
||||
"""
|
||||
|
||||
_endpoint: str
|
||||
_read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]]
|
||||
|
||||
def __init__(self, endpoint: str) -> None:
|
||||
"""
|
||||
Creates a new SSE server transport, which will direct the client to POST
|
||||
messages to the relative or absolute URL given.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._endpoint = endpoint
|
||||
self._read_stream_writers = {}
|
||||
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
logger.error("connect_sse received non-HTTP request")
|
||||
raise ValueError("connect_sse can only handle HTTP requests")
|
||||
|
||||
logger.debug("Setting up SSE connection")
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
session_id = uuid4()
|
||||
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
|
||||
self._read_stream_writers[session_id] = read_stream_writer
|
||||
logger.debug(f"Created new session with ID: {session_id}")
|
||||
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(
|
||||
0, dict[str, Any]
|
||||
)
|
||||
|
||||
async def sse_writer():
|
||||
logger.debug("Starting SSE writer")
|
||||
async with sse_stream_writer, write_stream_reader:
|
||||
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
|
||||
logger.debug(f"Sent endpoint event: {session_uri}")
|
||||
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending message via SSE: {message}")
|
||||
await sse_stream_writer.send(
|
||||
{
|
||||
"event": "message",
|
||||
"data": message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
response = EventSourceResponse(
|
||||
content=sse_stream_reader, data_sender_callable=sse_writer
|
||||
)
|
||||
logger.debug("Starting SSE response task")
|
||||
tg.start_soon(response, scope, receive, send)
|
||||
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
async def handle_post_message(
|
||||
self, scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
logger.debug("Handling POST message")
|
||||
request = Request(scope, receive)
|
||||
|
||||
session_id_param = request.query_params.get("session_id")
|
||||
if session_id_param is None:
|
||||
logger.warning("Received request without session_id")
|
||||
response = Response("session_id is required", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
try:
|
||||
session_id = UUID(hex=session_id_param)
|
||||
logger.debug(f"Parsed session ID: {session_id}")
|
||||
except ValueError:
|
||||
logger.warning(f"Received invalid session ID: {session_id_param}")
|
||||
response = Response("Invalid session ID", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
writer = self._read_stream_writers.get(session_id)
|
||||
if not writer:
|
||||
logger.warning(f"Could not find session for ID: {session_id}")
|
||||
response = Response("Could not find session", status_code=404)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
json = await request.json()
|
||||
logger.debug(f"Received JSON: {json}")
|
||||
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate(json)
|
||||
logger.debug(f"Validated client message: {message}")
|
||||
except ValidationError as err:
|
||||
logger.error(f"Failed to parse message: {err}")
|
||||
response = Response("Could not parse message", status_code=400)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(err)
|
||||
return
|
||||
|
||||
logger.debug(f"Sending message to writer: {message}")
|
||||
response = Response("Accepted", status_code=202)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(message)
|
||||
63
src/mcp/server/stdio.py
Normal file
63
src/mcp/server/stdio.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_server(
|
||||
stdin: anyio.AsyncFile[str] | None = None,
|
||||
stdout: anyio.AsyncFile[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Server transport for stdio: this communicates with an MCP client by reading
|
||||
from the current process' stdin and writing to stdout.
|
||||
"""
|
||||
# Purposely not using context managers for these, as we don't want to close
|
||||
# standard process handles.
|
||||
if not stdin:
|
||||
stdin = anyio.wrap_file(sys.stdin)
|
||||
if not stdout:
|
||||
stdout = anyio.wrap_file(sys.stdout)
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async def stdin_reader():
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for line in stdin:
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdout_writer():
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await stdout.write(json + "\n")
|
||||
await stdout.flush()
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(stdin_reader)
|
||||
tg.start_soon(stdout_writer)
|
||||
yield read_stream, write_stream
|
||||
46
src/mcp/server/types.py
Normal file
46
src/mcp/server/types.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
This module provides simpler types to use with the server for managing prompts
|
||||
and tools.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp.types import (
|
||||
BlobResourceContents,
|
||||
Role,
|
||||
ServerCapabilities,
|
||||
TextResourceContents,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageContent:
|
||||
type: Literal["image"]
|
||||
data: str
|
||||
mime_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddedResource:
|
||||
resource: TextResourceContents | BlobResourceContents
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
role: Role
|
||||
content: str | ImageContent | EmbeddedResource
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptResponse:
|
||||
messages: list[Message]
|
||||
desc: str | None = None
|
||||
|
||||
|
||||
class InitializationOptions(BaseModel):
|
||||
server_name: str
|
||||
server_version: str
|
||||
capabilities: ServerCapabilities
|
||||
61
src/mcp/server/websocket.py
Normal file
61
src/mcp/server/websocket.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from starlette.types import Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
||||
"""
|
||||
WebSocket server transport for MCP. This is an ASGI application, suitable to be
|
||||
used with a framework like Starlette and a server like Hypercorn.
|
||||
"""
|
||||
|
||||
websocket = WebSocket(scope, receive, send)
|
||||
await websocket.accept(subprotocol="mcp")
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async def ws_reader():
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for message in websocket.iter_json():
|
||||
try:
|
||||
client_message = JSONRPCMessage.model_validate(message)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(client_message)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
async def ws_writer():
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
obj = message.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
await websocket.send_json(obj)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(ws_reader)
|
||||
tg.start_soon(ws_writer)
|
||||
yield (read_stream, write_stream)
|
||||
0
src/mcp/shared/__init__.py
Normal file
0
src/mcp/shared/__init__.py
Normal file
14
src/mcp/shared/context.py
Normal file
14
src/mcp/shared/context.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.types import RequestId, RequestParams
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
9
src/mcp/shared/exceptions.py
Normal file
9
src/mcp/shared/exceptions.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from mcp.types import ErrorData
|
||||
|
||||
|
||||
class McpError(Exception):
|
||||
"""
|
||||
Exception type raised when an error arrives over an MCP connection.
|
||||
"""
|
||||
|
||||
error: ErrorData
|
||||
87
src/mcp/shared/memory.py
Normal file
87
src/mcp/shared/memory.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
In-memory transports
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.server import Server
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
MessageStream = tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_client_server_memory_streams() -> (
|
||||
AsyncGenerator[tuple[MessageStream, MessageStream], None]
|
||||
):
|
||||
"""
|
||||
Creates a pair of bidirectional memory streams for client-server communication.
|
||||
|
||||
Returns:
|
||||
A tuple of (client_streams, server_streams) where each is a tuple of
|
||||
(read_stream, write_stream)
|
||||
"""
|
||||
# Create streams for both directions
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage | Exception
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage | Exception
|
||||
](1)
|
||||
|
||||
client_streams = (server_to_client_receive, client_to_server_send)
|
||||
server_streams = (client_to_server_receive, server_to_client_send)
|
||||
|
||||
async with (
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
client_to_server_receive,
|
||||
server_to_client_send,
|
||||
):
|
||||
yield client_streams, server_streams
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_connected_server_and_client_session(
|
||||
server: Server,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
raise_exceptions: bool = False,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||
async with create_client_server_memory_streams() as (
|
||||
client_streams,
|
||||
server_streams,
|
||||
):
|
||||
client_read, client_write = client_streams
|
||||
server_read, server_write = server_streams
|
||||
|
||||
# Create a cancel scope for the server task
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(
|
||||
lambda: server.run(
|
||||
server_read,
|
||||
server_write,
|
||||
server.create_initialization_options(),
|
||||
raise_exceptions=raise_exceptions,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async with ClientSession(
|
||||
read_stream=client_read,
|
||||
write_stream=client_write,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
) as client_session:
|
||||
await client_session.initialize()
|
||||
yield client_session
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
40
src/mcp/shared/progress.py
Normal file
40
src/mcp/shared/progress.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.types import ProgressToken
|
||||
|
||||
|
||||
class Progress(BaseModel):
|
||||
progress: float
|
||||
total: float | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProgressContext:
|
||||
session: BaseSession
|
||||
progress_token: ProgressToken
|
||||
total: float | None
|
||||
current: float = field(default=0.0, init=False)
|
||||
|
||||
async def progress(self, amount: float) -> None:
|
||||
self.current += amount
|
||||
|
||||
await self.session.send_progress_notification(
|
||||
self.progress_token, self.current, total=self.total
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def progress(ctx: RequestContext, total: float | None = None):
|
||||
if ctx.meta is None or ctx.meta.progressToken is None:
|
||||
raise ValueError("No progress token provided")
|
||||
|
||||
progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
|
||||
try:
|
||||
yield progress_ctx
|
||||
finally:
|
||||
pass
|
||||
288
src/mcp/shared/session.py
Normal file
288
src/mcp/shared/session.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from datetime import timedelta
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
import httpx
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import (
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestParams,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
)
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar(
|
||||
"ReceiveNotificationT", ClientNotification, ServerNotification
|
||||
)
|
||||
|
||||
RequestId = str | int
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
def __init__(
|
||||
self,
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: "BaseSession",
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self._responded = False
|
||||
|
||||
async def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
assert not self._responded, "Request already responded to"
|
||||
self._responded = True
|
||||
|
||||
await self._session._send_response(
|
||||
request_id=self.request_id, response=response
|
||||
)
|
||||
|
||||
|
||||
class BaseSession(
|
||||
AbstractAsyncContextManager,
|
||||
Generic[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT,
|
||||
],
|
||||
):
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features
|
||||
like request/response linking, notifications, and progress.
|
||||
|
||||
This class is an async context manager that automatically starts processing
|
||||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[
|
||||
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
|
||||
]
|
||||
_request_id: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
# If none, reading will never time out
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
self._read_stream = read_stream
|
||||
self._write_stream = write_stream
|
||||
self._response_streams = {}
|
||||
self._request_id = 0
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._read_timeout_seconds = read_timeout_seconds
|
||||
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[
|
||||
RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception
|
||||
]()
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
self._task_group = anyio.create_task_group()
|
||||
await self._task_group.__aenter__()
|
||||
self._task_group.start_soon(self._receive_loop)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Using BaseSession as a context manager should not block on exit (this
|
||||
# would be very surprising behavior), so make sure to cancel the tasks
|
||||
# in the task group.
|
||||
self._task_group.cancel_scope.cancel()
|
||||
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def send_request(
|
||||
self,
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
response contains an error.
|
||||
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_stream, response_stream_reader = anyio.create_memory_object_stream[
|
||||
JSONRPCResponse | JSONRPCError
|
||||
](1)
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||
|
||||
try:
|
||||
with anyio.fail_after(
|
||||
None
|
||||
if self._read_timeout_seconds is None
|
||||
else self._read_timeout_seconds.total_seconds()
|
||||
):
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
except TimeoutError:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=httpx.codes.REQUEST_TIMEOUT,
|
||||
message=(
|
||||
f"Timed out while waiting for response to "
|
||||
f"{request.__class__.__name__}. Waited "
|
||||
f"{self._read_timeout_seconds} seconds."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(response_or_error, JSONRPCError):
|
||||
raise McpError(response_or_error.error)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
async def send_notification(self, notification: SendNotificationT) -> None:
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
|
||||
|
||||
async def _send_response(
|
||||
self, request_id: RequestId, response: SendResultT | ErrorData
|
||||
) -> None:
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
|
||||
else:
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
)
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
async with (
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
self._incoming_message_stream_writer,
|
||||
):
|
||||
async for message in self._read_stream:
|
||||
if isinstance(message, Exception):
|
||||
await self._incoming_message_stream_writer.send(message)
|
||||
elif isinstance(message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
responder = RequestResponder(
|
||||
request_id=message.root.id,
|
||||
request_meta=validated_request.root.params._meta
|
||||
if validated_request.root.params
|
||||
else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
)
|
||||
|
||||
await self._received_request(responder)
|
||||
if not responder._responded:
|
||||
await self._incoming_message_stream_writer.send(responder)
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
)
|
||||
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(notification)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.root.id, None)
|
||||
if stream:
|
||||
await stream.send(message.root)
|
||||
else:
|
||||
await self._incoming_message_stream_writer.send(
|
||||
RuntimeError(
|
||||
"Received response with an unknown "
|
||||
f"request ID: {message}"
|
||||
)
|
||||
)
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
|
||||
) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to
|
||||
listen on the message stream.
|
||||
|
||||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
|
||||
async def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
|
||||
@property
|
||||
def incoming_messages(
|
||||
self,
|
||||
) -> MemoryObjectReceiveStream[
|
||||
RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception
|
||||
]:
|
||||
return self._incoming_message_stream_reader
|
||||
3
src/mcp/shared/version.py
Normal file
3
src/mcp/shared/version.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION]
|
||||
1041
src/mcp/types.py
Normal file
1041
src/mcp/types.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user