rename mcp_python to mcp

This commit is contained in:
David Soria Parra
2024-11-11 12:31:36 +00:00
parent aa164ab556
commit ed87ae9f06
31 changed files with 128 additions and 127 deletions

114
src/mcp/__init__.py Normal file
View File

@@ -0,0 +1,114 @@
from .client.session import ClientSession
from .client.stdio import StdioServerParameters, stdio_client
from .server.session import ServerSession
from .server.stdio import stdio_server
from .shared.exceptions import McpError
from .types import (
CallToolRequest,
ClientCapabilities,
ClientNotification,
ClientRequest,
ClientResult,
CompleteRequest,
CreateMessageRequest,
CreateMessageResult,
ErrorData,
GetPromptRequest,
GetPromptResult,
Implementation,
IncludeContext,
InitializedNotification,
InitializeRequest,
InitializeResult,
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
ListPromptsRequest,
ListPromptsResult,
ListResourcesRequest,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
LoggingMessageNotification,
Notification,
PingRequest,
ProgressNotification,
PromptsCapability,
ReadResourceRequest,
ReadResourceResult,
Resource,
ResourcesCapability,
ResourceUpdatedNotification,
RootsCapability,
SamplingMessage,
ServerCapabilities,
ServerNotification,
ServerRequest,
ServerResult,
SetLevelRequest,
StopReason,
SubscribeRequest,
Tool,
ToolsCapability,
UnsubscribeRequest,
)
from .types import (
Role as SamplingRole,
)
__all__ = [
"CallToolRequest",
"ClientCapabilities",
"ClientNotification",
"ClientRequest",
"ClientResult",
"ClientSession",
"CreateMessageRequest",
"CreateMessageResult",
"ErrorData",
"GetPromptRequest",
"GetPromptResult",
"Implementation",
"IncludeContext",
"InitializeRequest",
"InitializeResult",
"InitializedNotification",
"JSONRPCError",
"JSONRPCRequest",
"ListPromptsRequest",
"ListPromptsResult",
"ListResourcesRequest",
"ListResourcesResult",
"ListToolsResult",
"LoggingLevel",
"LoggingMessageNotification",
"McpError",
"Notification",
"PingRequest",
"ProgressNotification",
"PromptsCapability",
"ReadResourceRequest",
"ReadResourceResult",
"ResourcesCapability",
"ResourceUpdatedNotification",
"Resource",
"RootsCapability",
"SamplingMessage",
"SamplingRole",
"ServerCapabilities",
"ServerNotification",
"ServerRequest",
"ServerResult",
"ServerSession",
"SetLevelRequest",
"StdioServerParameters",
"StopReason",
"SubscribeRequest",
"Tool",
"ToolsCapability",
"UnsubscribeRequest",
"stdio_client",
"stdio_server",
"CompleteRequest",
"JSONRPCResponse",
]

View File

View File

@@ -0,0 +1,76 @@
import logging
import sys
from functools import partial
from urllib.parse import urlparse
import anyio
import click
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client")
async def receive_loop(session: ClientSession):
logger.info("Starting receive loop")
async for message in session.incoming_messages:
if isinstance(message, Exception):
logger.error("Error: %s", message)
continue
logger.info("Received message from server: %s", message)
async def run_session(read_stream, write_stream):
async with (
ClientSession(read_stream, write_stream) as session,
anyio.create_task_group() as tg,
):
tg.start_soon(receive_loop, session)
logger.info("Initializing session")
await session.initialize()
logger.info("Initialized")
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
env_dict = dict(env)
if urlparse(command_or_url).scheme in ("http", "https"):
# Use SSE client for HTTP(S) URLs
async with sse_client(command_or_url) as streams:
await run_session(*streams)
else:
# Use stdio client for commands
server_parameters = StdioServerParameters(
command=command_or_url, args=args, env=env_dict
)
async with stdio_client(server_parameters) as streams:
await run_session(*streams)
@click.command()
@click.argument("command_or_url")
@click.argument("args", nargs=-1)
@click.option(
"--env",
"-e",
multiple=True,
nargs=2,
metavar="KEY VALUE",
help="Environment variables to set. Can be used multiple times.",
)
def cli(*args, **kwargs):
anyio.run(partial(main, *args, **kwargs), backend="trio")
if __name__ == "__main__":
cli()

313
src/mcp/client/session.py Normal file
View File

@@ -0,0 +1,313 @@
from datetime import timedelta
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from mcp.shared.session import BaseSession
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
LATEST_PROTOCOL_VERSION,
CallToolResult,
ClientCapabilities,
ClientNotification,
ClientRequest,
ClientResult,
CompleteResult,
EmptyResult,
GetPromptResult,
Implementation,
InitializedNotification,
InitializeResult,
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
PromptReference,
ReadResourceResult,
ResourceReference,
RootsCapability,
ServerNotification,
ServerRequest,
)
class ClientSession(
BaseSession[
ClientRequest,
ClientNotification,
ClientResult,
ServerRequest,
ServerNotification,
]
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
ServerRequest,
ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
async def initialize(self) -> InitializeResult:
from mcp.types import (
InitializeRequest,
InitializeRequestParams,
)
result = await self.send_request(
ClientRequest(
InitializeRequest(
method="initialize",
params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(
sampling=None,
experimental=None,
roots=RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True
),
),
clientInfo=Implementation(name="mcp", version="0.1.0"),
),
)
),
InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(
"Unsupported protocol version from the server: "
f"{result.protocolVersion}"
)
await self.send_notification(
ClientNotification(
InitializedNotification(method="notifications/initialized")
)
)
return result
async def send_ping(self) -> EmptyResult:
"""Send a ping request."""
from mcp.types import PingRequest
return await self.send_request(
ClientRequest(
PingRequest(
method="ping",
)
),
EmptyResult,
)
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
from mcp.types import (
ProgressNotification,
ProgressNotificationParams,
)
await self.send_notification(
ClientNotification(
ProgressNotification(
method="notifications/progress",
params=ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
),
),
)
)
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
"""Send a logging/setLevel request."""
from mcp.types import (
SetLevelRequest,
SetLevelRequestParams,
)
return await self.send_request(
ClientRequest(
SetLevelRequest(
method="logging/setLevel",
params=SetLevelRequestParams(level=level),
)
),
EmptyResult,
)
async def list_resources(self) -> ListResourcesResult:
"""Send a resources/list request."""
from mcp.types import (
ListResourcesRequest,
)
return await self.send_request(
ClientRequest(
ListResourcesRequest(
method="resources/list",
)
),
ListResourcesResult,
)
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
"""Send a resources/read request."""
from mcp.types import (
ReadResourceRequest,
ReadResourceRequestParams,
)
return await self.send_request(
ClientRequest(
ReadResourceRequest(
method="resources/read",
params=ReadResourceRequestParams(uri=uri),
)
),
ReadResourceResult,
)
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
"""Send a resources/subscribe request."""
from mcp.types import (
SubscribeRequest,
SubscribeRequestParams,
)
return await self.send_request(
ClientRequest(
SubscribeRequest(
method="resources/subscribe",
params=SubscribeRequestParams(uri=uri),
)
),
EmptyResult,
)
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
"""Send a resources/unsubscribe request."""
from mcp.types import (
UnsubscribeRequest,
UnsubscribeRequestParams,
)
return await self.send_request(
ClientRequest(
UnsubscribeRequest(
method="resources/unsubscribe",
params=UnsubscribeRequestParams(uri=uri),
)
),
EmptyResult,
)
async def call_tool(
self, name: str, arguments: dict | None = None
) -> CallToolResult:
"""Send a tools/call request."""
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
)
return await self.send_request(
ClientRequest(
CallToolRequest(
method="tools/call",
params=CallToolRequestParams(name=name, arguments=arguments),
)
),
CallToolResult,
)
async def list_prompts(self) -> ListPromptsResult:
"""Send a prompts/list request."""
from mcp.types import ListPromptsRequest
return await self.send_request(
ClientRequest(
ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
)
async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
"""Send a prompts/get request."""
from mcp.types import GetPromptRequest, GetPromptRequestParams
return await self.send_request(
ClientRequest(
GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
)
async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
"""Send a completion/complete request."""
from mcp.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)
return await self.send_request(
ClientRequest(
CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
),
)
),
CompleteResult,
)
async def list_tools(self) -> ListToolsResult:
"""Send a tools/list request."""
from mcp.types import ListToolsRequest
return await self.send_request(
ClientRequest(
ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
)
async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp.types import RootsListChangedNotification
await self.send_notification(
ClientNotification(
RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)

144
src/mcp/client/sse.py Normal file
View File

@@ -0,0 +1,144 @@
import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from mcp.types import JSONRPCMessage
logger = logging.getLogger(__name__)
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)
@asynccontextmanager
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
):
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx.AsyncClient(headers=headers) as client:
async with aconnect_sse(
client,
"GET",
url,
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse():
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(
f"Received endpoint URL: {endpoint_url}"
)
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme
!= endpoint_parsed.scheme
):
error_msg = (
"Endpoint origin does not match "
f"connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
task_status.started(endpoint_url)
case "message":
try:
message = (
JSONRPCMessage.model_validate_json(
sse.data
)
)
logger.debug(
f"Received server message: {message}"
)
except Exception as exc:
logger.error(
f"Error parsing server message: {exc}"
)
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(message)
except Exception as exc:
logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for message in write_stream_reader:
logger.debug(f"Sending client message: {message}")
response = await client.post(
endpoint_url,
json=message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
finally:
await write_stream.aclose()
endpoint_url = await tg.start(sse_reader)
logger.info(
f"Starting post writer with endpoint URL: {endpoint_url}"
)
tg.start_soon(post_writer, endpoint_url)
try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()

128
src/mcp/client/stdio.py Normal file
View File

@@ -0,0 +1,128 @@
import os
import sys
from contextlib import asynccontextmanager
import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field
from mcp.types import JSONRPCMessage
# Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = (
[
"APPDATA",
"HOMEDRIVE",
"HOMEPATH",
"LOCALAPPDATA",
"PATH",
"PROCESSOR_ARCHITECTURE",
"SYSTEMDRIVE",
"SYSTEMROOT",
"TEMP",
"USERNAME",
"USERPROFILE",
]
if sys.platform == "win32"
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
)
def get_default_environment() -> dict[str, str]:
"""
Returns a default environment object including only environment variables deemed
safe to inherit.
"""
env: dict[str, str] = {}
for key in DEFAULT_INHERITED_ENV_VARS:
value = os.environ.get(key)
if value is None:
continue
if value.startswith("()"):
# Skip functions, which are a security risk
continue
env[key] = value
return env
class StdioServerParameters(BaseModel):
command: str
"""The executable to run to start the server."""
args: list[str] = Field(default_factory=list)
"""Command line arguments to pass to the executable."""
env: dict[str, str] | None = None
"""
The environment to use when spawning the process.
If not specified, the result of get_default_environment() will be used.
"""
@asynccontextmanager
async def stdio_client(server: StdioServerParameters):
"""
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
process = await anyio.open_process(
[server.command, *server.args],
env=server.env if server.env is not None else get_default_environment(),
stderr=sys.stderr,
)
async def stdout_reader():
assert process.stdout, "Opened process is missing stdout"
try:
async with read_stream_writer:
buffer = ""
async for chunk in TextReceiveStream(process.stdout):
lines = (buffer + chunk).split("\n")
buffer = lines.pop()
for line in lines:
try:
message = JSONRPCMessage.model_validate_json(line)
except Exception as exc:
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(message)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
async def stdin_writer():
assert process.stdin, "Opened process is missing stdin"
try:
async with write_stream_reader:
async for message in write_stream_reader:
json = message.model_dump_json(by_alias=True, exclude_none=True)
await process.stdin.send((json + "\n").encode())
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
async with (
anyio.create_task_group() as tg,
process,
):
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
yield read_stream, write_stream

0
src/mcp/py.typed Normal file
View File

513
src/mcp/server/__init__.py Normal file
View File

@@ -0,0 +1,513 @@
import contextvars
import logging
import warnings
from collections.abc import Awaitable, Callable
from typing import Any, Sequence
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from mcp.server import types
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.session import RequestResponder
from mcp.types import (
METHOD_NOT_FOUND,
CallToolRequest,
ClientNotification,
ClientRequest,
CompleteRequest,
EmbeddedResource,
EmptyResult,
ErrorData,
JSONRPCMessage,
ListPromptsRequest,
ListPromptsResult,
ListResourcesRequest,
ListResourcesResult,
ListToolsRequest,
ListToolsResult,
LoggingCapability,
LoggingLevel,
PingRequest,
ProgressNotification,
Prompt,
PromptMessage,
PromptReference,
PromptsCapability,
ReadResourceRequest,
ReadResourceResult,
Resource,
ResourceReference,
ResourcesCapability,
ServerCapabilities,
ServerResult,
SetLevelRequest,
SubscribeRequest,
TextContent,
Tool,
ToolsCapability,
UnsubscribeRequest,
)
logger = logging.getLogger(__name__)
request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
"request_ctx"
)
class NotificationOptions:
def __init__(
self,
prompts_changed: bool = False,
resources_changed: bool = False,
tools_changed: bool = False,
):
self.prompts_changed = prompts_changed
self.resources_changed = resources_changed
self.tools_changed = tools_changed
class Server:
def __init__(self, name: str):
self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
self.notification_options = NotificationOptions()
logger.debug(f"Initializing server '{name}'")
def create_initialization_options(
self,
notification_options: NotificationOptions | None = None,
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
def pkg_version(package: str) -> str:
try:
from importlib.metadata import version
v = version(package)
if v is not None:
return v
except Exception:
pass
return "unknown"
return types.InitializationOptions(
server_name=self.name,
server_version=pkg_version("mcp"),
capabilities=self.get_capabilities(
notification_options or NotificationOptions(),
experimental_capabilities or {},
),
)
def get_capabilities(
self,
notification_options: NotificationOptions,
experimental_capabilities: dict[str, dict[str, Any]],
) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
prompts_capability = None
resources_capability = None
tools_capability = None
logging_capability = None
# Set prompt capabilities if handler exists
if ListPromptsRequest in self.request_handlers:
prompts_capability = PromptsCapability(
listChanged=notification_options.prompts_changed
)
# Set resource capabilities if handler exists
if ListResourcesRequest in self.request_handlers:
resources_capability = ResourcesCapability(
subscribe=False, listChanged=notification_options.resources_changed
)
# Set tool capabilities if handler exists
if ListToolsRequest in self.request_handlers:
tools_capability = ToolsCapability(
listChanged=notification_options.tools_changed
)
# Set logging capabilities if handler exists
if SetLevelRequest in self.request_handlers:
logging_capability = LoggingCapability()
return ServerCapabilities(
prompts=prompts_capability,
resources=resources_capability,
tools=tools_capability,
logging=logging_capability,
experimental=experimental_capabilities,
)
@property
def request_context(self) -> RequestContext:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()
def list_prompts(self):
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
logger.debug("Registering handler for PromptListRequest")
async def handler(_: Any):
prompts = await func()
return ServerResult(ListPromptsResult(prompts=prompts))
self.request_handlers[ListPromptsRequest] = handler
return func
return decorator
def get_prompt(self):
from mcp.types import (
GetPromptRequest,
GetPromptResult,
ImageContent,
)
from mcp.types import (
Role as Role,
)
def decorator(
func: Callable[
[str, dict[str, str] | None], Awaitable[types.PromptResponse]
],
):
logger.debug("Registering handler for GetPromptRequest")
async def handler(req: GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments)
messages: list[PromptMessage] = []
for message in prompt_get.messages:
match message.content:
case str() as text_content:
content = TextContent(type="text", text=text_content)
case types.ImageContent() as img_content:
content = ImageContent(
type="image",
data=img_content.data,
mimeType=img_content.mime_type,
)
case types.EmbeddedResource() as resource:
content = EmbeddedResource(
type="resource", resource=resource.resource
)
case _:
raise ValueError(
f"Unexpected content type: {type(message.content)}"
)
prompt_message = PromptMessage(role=message.role, content=content)
messages.append(prompt_message)
return ServerResult(
GetPromptResult(description=prompt_get.desc, messages=messages)
)
self.request_handlers[GetPromptRequest] = handler
return func
return decorator
def list_resources(self):
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
logger.debug("Registering handler for ListResourcesRequest")
async def handler(_: Any):
resources = await func()
return ServerResult(ListResourcesResult(resources=resources))
self.request_handlers[ListResourcesRequest] = handler
return func
return decorator
def read_resource(self):
from mcp.types import (
BlobResourceContents,
TextResourceContents,
)
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
logger.debug("Registering handler for ReadResourceRequest")
async def handler(req: ReadResourceRequest):
result = await func(req.params.uri)
match result:
case str(s):
content = TextResourceContents(
uri=req.params.uri,
text=s,
mimeType="text/plain",
)
case bytes(b):
import base64
content = BlobResourceContents(
uri=req.params.uri,
blob=base64.urlsafe_b64encode(b).decode(),
mimeType="application/octet-stream",
)
return ServerResult(
ReadResourceResult(
contents=[content],
)
)
self.request_handlers[ReadResourceRequest] = handler
return func
return decorator
def set_logging_level(self):
from mcp.types import EmptyResult
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
logger.debug("Registering handler for SetLevelRequest")
async def handler(req: SetLevelRequest):
await func(req.params.level)
return ServerResult(EmptyResult())
self.request_handlers[SetLevelRequest] = handler
return func
return decorator
def subscribe_resource(self):
from mcp.types import EmptyResult
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for SubscribeRequest")
async def handler(req: SubscribeRequest):
await func(req.params.uri)
return ServerResult(EmptyResult())
self.request_handlers[SubscribeRequest] = handler
return func
return decorator
def unsubscribe_resource(self):
from mcp.types import EmptyResult
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for UnsubscribeRequest")
async def handler(req: UnsubscribeRequest):
await func(req.params.uri)
return ServerResult(EmptyResult())
self.request_handlers[UnsubscribeRequest] = handler
return func
return decorator
def list_tools(self):
def decorator(func: Callable[[], Awaitable[list[Tool]]]):
logger.debug("Registering handler for ListToolsRequest")
async def handler(_: Any):
tools = await func()
return ServerResult(ListToolsResult(tools=tools))
self.request_handlers[ListToolsRequest] = handler
return func
return decorator
def call_tool(self):
from mcp.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
)
def decorator(
func: Callable[
...,
Awaitable[Sequence[str | types.ImageContent | types.EmbeddedResource]],
],
):
logger.debug("Registering handler for CallToolRequest")
async def handler(req: CallToolRequest):
try:
results = await func(req.params.name, (req.params.arguments or {}))
content = []
for result in results:
match result:
case str() as text:
content.append(TextContent(type="text", text=text))
case types.ImageContent() as img:
content.append(
ImageContent(
type="image",
data=img.data,
mimeType=img.mime_type,
)
)
case types.EmbeddedResource() as resource:
content.append(
EmbeddedResource(
type="resource", resource=resource.resource
)
)
return ServerResult(CallToolResult(content=content, isError=False))
except Exception as e:
return ServerResult(
CallToolResult(
content=[TextContent(type="text", text=str(e))],
isError=True,
)
)
self.request_handlers[CallToolRequest] = handler
return func
return decorator
def progress_notification(self):
def decorator(
func: Callable[[str | int, float, float | None], Awaitable[None]],
):
logger.debug("Registering handler for ProgressNotification")
async def handler(req: ProgressNotification):
await func(
req.params.progressToken, req.params.progress, req.params.total
)
self.notification_handlers[ProgressNotification] = handler
return func
return decorator
def completion(self):
"""Provides completions for prompts and resource templates"""
from mcp.types import CompleteResult, Completion, CompletionArgument
def decorator(
func: Callable[
[PromptReference | ResourceReference, CompletionArgument],
Awaitable[Completion | None],
],
):
logger.debug("Registering handler for CompleteRequest")
async def handler(req: CompleteRequest):
completion = await func(req.params.ref, req.params.argument)
return ServerResult(
CompleteResult(
completion=completion
if completion is not None
else Completion(values=[], total=None, hasMore=None),
)
)
self.request_handlers[CompleteRequest] = handler
return func
return decorator
async def run(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions,
# When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
match message:
case RequestResponder(request=ClientRequest(root=req)):
logger.info(
f"Processing request of type {type(req).__name__}"
)
if type(req) in self.request_handlers:
handler = self.request_handlers[type(req)]
logger.debug(
f"Dispatching request of type {type(req).__name__}"
)
token = None
try:
# Set our global state that can be retrieved via
# app.get_request_context()
token = request_ctx.set(
RequestContext(
message.request_id,
message.request_meta,
session,
)
)
response = await handler(req)
except Exception as err:
if raise_exceptions:
raise err
response = ErrorData(
code=0, message=str(err), data=None
)
finally:
# Reset the global state after we are done
if token is not None:
request_ctx.reset(token)
await message.respond(response)
else:
await message.respond(
ErrorData(
code=METHOD_NOT_FOUND,
message="Method not found",
)
)
logger.debug("Response sent")
case ClientNotification(root=notify):
if type(notify) in self.notification_handlers:
assert type(notify) in self.notification_handlers
handler = self.notification_handlers[type(notify)]
logger.debug(
f"Dispatching notification of type "
f"{type(notify).__name__}"
)
try:
await handler(notify)
except Exception as err:
logger.error(
f"Uncaught exception in notification handler: "
f"{err}"
)
for warning in w:
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
)
async def _ping_handler(request: PingRequest) -> ServerResult:
return ServerResult(EmptyResult())

View File

@@ -0,0 +1,50 @@
import importlib.metadata
import logging
import sys
import anyio
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server
from mcp.server.types import InitializationOptions
from mcp.types import ServerCapabilities
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("server")
async def receive_loop(session: ServerSession):
logger.info("Starting receive loop")
async for message in session.incoming_messages:
if isinstance(message, Exception):
logger.error("Error: %s", message)
continue
logger.info("Received message from client: %s", message)
async def main():
version = importlib.metadata.version("mcp")
async with stdio_server() as (read_stream, write_stream):
async with (
ServerSession(
read_stream,
write_stream,
InitializationOptions(
server_name="mcp",
server_version=version,
capabilities=ServerCapabilities(),
),
) as session,
write_stream,
):
await receive_loop(session)
if __name__ == "__main__":
anyio.run(main, backend="trio")

250
src/mcp/server/session.py Normal file
View File

@@ -0,0 +1,250 @@
from enum import Enum
from typing import Any
import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from mcp.server.types import InitializationOptions
from mcp.shared.session import (
BaseSession,
RequestResponder,
)
from mcp.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
CreateMessageResult,
EmptyResult,
Implementation,
IncludeContext,
InitializedNotification,
InitializeRequest,
InitializeResult,
JSONRPCMessage,
ListRootsResult,
LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage,
ServerNotification,
ServerRequest,
ServerResult,
ToolListChangedNotification,
)
class InitializationState(Enum):
NotInitialized = 1
Initializing = 2
Initialized = 3
class ServerSession(
BaseSession[
ServerRequest,
ServerNotification,
ServerResult,
ClientRequest,
ClientNotification,
]
):
_initialized: InitializationState = InitializationState.NotInitialized
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
init_options: InitializationOptions,
) -> None:
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
async def _received_request(
self, responder: RequestResponder[ClientRequest, ServerResult]
):
match responder.request.root:
case InitializeRequest():
self._initialization_state = InitializationState.Initializing
await responder.respond(
ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
)
)
)
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError(
"Received request before initialization was complete"
)
async def _received_notification(self, notification: ClientNotification) -> None:
# Need this to avoid ASYNC910
await anyio.lowlevel.checkpoint()
match notification.root:
case InitializedNotification():
self._initialization_state = InitializationState.Initialized
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError(
"Received notification before initialization was complete"
)
async def send_log_message(
self, level: LoggingLevel, data: Any, logger: str | None = None
) -> None:
"""Send a log message notification."""
from mcp.types import (
LoggingMessageNotification,
LoggingMessageNotificationParams,
)
await self.send_notification(
ServerNotification(
LoggingMessageNotification(
method="notifications/message",
params=LoggingMessageNotificationParams(
level=level,
data=data,
logger=logger,
),
)
)
)
async def send_resource_updated(self, uri: AnyUrl) -> None:
"""Send a resource updated notification."""
from mcp.types import (
ResourceUpdatedNotification,
ResourceUpdatedNotificationParams,
)
await self.send_notification(
ServerNotification(
ResourceUpdatedNotification(
method="notifications/resources/updated",
params=ResourceUpdatedNotificationParams(uri=uri),
)
)
)
async def create_message(
self,
messages: list[SamplingMessage],
*,
max_tokens: int,
system_prompt: str | None = None,
include_context: IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
) -> CreateMessageResult:
"""Send a sampling/create_message request."""
from mcp.types import (
CreateMessageRequest,
CreateMessageRequestParams,
)
return await self.send_request(
ServerRequest(
CreateMessageRequest(
method="sampling/createMessage",
params=CreateMessageRequestParams(
messages=messages,
systemPrompt=system_prompt,
includeContext=include_context,
temperature=temperature,
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
),
)
),
CreateMessageResult,
)
async def list_roots(self) -> ListRootsResult:
"""Send a roots/list request."""
from mcp.types import ListRootsRequest
return await self.send_request(
ServerRequest(
ListRootsRequest(
method="roots/list",
)
),
ListRootsResult,
)
async def send_ping(self) -> EmptyResult:
"""Send a ping request."""
from mcp.types import PingRequest
return await self.send_request(
ServerRequest(
PingRequest(
method="ping",
)
),
EmptyResult,
)
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
from mcp.types import ProgressNotification, ProgressNotificationParams
await self.send_notification(
ServerNotification(
ProgressNotification(
method="notifications/progress",
params=ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
),
)
)
)
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
ServerNotification(
ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
)
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
ServerNotification(
ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
)
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
ServerNotification(
PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)
)

140
src/mcp/server/sse.py Normal file
View File

@@ -0,0 +1,140 @@
import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import quote
from uuid import UUID, uuid4
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from mcp.types import JSONRPCMessage
logger = logging.getLogger(__name__)
class SseServerTransport:
"""
SSE server transport for MCP. This class provides _two_ ASGI applications,
suitable to be used with a framework like Starlette and a server like Hypercorn:
1. connect_sse() is an ASGI application which receives incoming GET requests,
and sets up a new SSE stream to send server messages to the client.
2. handle_post_message() is an ASGI application which receives incoming POST
requests, which should contain client messages that link to a
previously-established SSE session.
"""
_endpoint: str
_read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]]
def __init__(self, endpoint: str) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.
"""
super().__init__()
self._endpoint = endpoint
self._read_stream_writers = {}
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
@asynccontextmanager
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
logger.error("connect_sse received non-HTTP request")
raise ValueError("connect_sse can only handle HTTP requests")
logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
session_id = uuid4()
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(
0, dict[str, Any]
)
async def sse_writer():
logger.debug("Starting SSE writer")
async with sse_stream_writer, write_stream_reader:
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
logger.debug(f"Sent endpoint event: {session_uri}")
async for message in write_stream_reader:
logger.debug(f"Sending message via SSE: {message}")
await sse_stream_writer.send(
{
"event": "message",
"data": message.model_dump_json(
by_alias=True, exclude_none=True
),
}
)
async with anyio.create_task_group() as tg:
response = EventSourceResponse(
content=sse_stream_reader, data_sender_callable=sse_writer
)
logger.debug("Starting SSE response task")
tg.start_soon(response, scope, receive, send)
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
async def handle_post_message(
self, scope: Scope, receive: Receive, send: Send
) -> None:
logger.debug("Handling POST message")
request = Request(scope, receive)
session_id_param = request.query_params.get("session_id")
if session_id_param is None:
logger.warning("Received request without session_id")
response = Response("session_id is required", status_code=400)
return await response(scope, receive, send)
try:
session_id = UUID(hex=session_id_param)
logger.debug(f"Parsed session ID: {session_id}")
except ValueError:
logger.warning(f"Received invalid session ID: {session_id_param}")
response = Response("Invalid session ID", status_code=400)
return await response(scope, receive, send)
writer = self._read_stream_writers.get(session_id)
if not writer:
logger.warning(f"Could not find session for ID: {session_id}")
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)
json = await request.json()
logger.debug(f"Received JSON: {json}")
try:
message = JSONRPCMessage.model_validate(json)
logger.debug(f"Validated client message: {message}")
except ValidationError as err:
logger.error(f"Failed to parse message: {err}")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
await writer.send(err)
return
logger.debug(f"Sending message to writer: {message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(message)

63
src/mcp/server/stdio.py Normal file
View File

@@ -0,0 +1,63 @@
import sys
from contextlib import asynccontextmanager
import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.types import JSONRPCMessage
@asynccontextmanager
async def stdio_server(
stdin: anyio.AsyncFile[str] | None = None,
stdout: anyio.AsyncFile[str] | None = None,
):
"""
Server transport for stdio: this communicates with an MCP client by reading
from the current process' stdin and writing to stdout.
"""
# Purposely not using context managers for these, as we don't want to close
# standard process handles.
if not stdin:
stdin = anyio.wrap_file(sys.stdin)
if not stdout:
stdout = anyio.wrap_file(sys.stdout)
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async def stdin_reader():
try:
async with read_stream_writer:
async for line in stdin:
try:
message = JSONRPCMessage.model_validate_json(line)
except Exception as exc:
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(message)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
async def stdout_writer():
try:
async with write_stream_reader:
async for message in write_stream_reader:
json = message.model_dump_json(by_alias=True, exclude_none=True)
await stdout.write(json + "\n")
await stdout.flush()
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
async with anyio.create_task_group() as tg:
tg.start_soon(stdin_reader)
tg.start_soon(stdout_writer)
yield read_stream, write_stream

46
src/mcp/server/types.py Normal file
View File

@@ -0,0 +1,46 @@
"""
This module provides simpler types to use with the server for managing prompts
and tools.
"""
from dataclasses import dataclass
from typing import Literal
from pydantic import BaseModel
from mcp.types import (
BlobResourceContents,
Role,
ServerCapabilities,
TextResourceContents,
)
@dataclass
class ImageContent:
type: Literal["image"]
data: str
mime_type: str
@dataclass
class EmbeddedResource:
resource: TextResourceContents | BlobResourceContents
@dataclass
class Message:
role: Role
content: str | ImageContent | EmbeddedResource
@dataclass
class PromptResponse:
messages: list[Message]
desc: str | None = None
class InitializationOptions(BaseModel):
server_name: str
server_version: str
capabilities: ServerCapabilities

View File

@@ -0,0 +1,61 @@
import logging
from contextlib import asynccontextmanager
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket
from mcp.types import JSONRPCMessage
logger = logging.getLogger(__name__)
@asynccontextmanager
async def websocket_server(scope: Scope, receive: Receive, send: Send):
"""
WebSocket server transport for MCP. This is an ASGI application, suitable to be
used with a framework like Starlette and a server like Hypercorn.
"""
websocket = WebSocket(scope, receive, send)
await websocket.accept(subprotocol="mcp")
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async def ws_reader():
try:
async with read_stream_writer:
async for message in websocket.iter_json():
try:
client_message = JSONRPCMessage.model_validate(message)
except Exception as exc:
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(client_message)
except anyio.ClosedResourceError:
await websocket.close()
async def ws_writer():
try:
async with write_stream_reader:
async for message in write_stream_reader:
obj = message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await websocket.send_json(obj)
except anyio.ClosedResourceError:
await websocket.close()
async with anyio.create_task_group() as tg:
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
yield (read_stream, write_stream)

View File

14
src/mcp/shared/context.py Normal file
View File

@@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Generic, TypeVar
from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParams
SessionT = TypeVar("SessionT", bound=BaseSession)
@dataclass
class RequestContext(Generic[SessionT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT

View File

@@ -0,0 +1,9 @@
from mcp.types import ErrorData
class McpError(Exception):
"""
Exception type raised when an error arrives over an MCP connection.
"""
error: ErrorData

87
src/mcp/shared/memory.py Normal file
View File

@@ -0,0 +1,87 @@
"""
In-memory transports
"""
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.session import ClientSession
from mcp.server import Server
from mcp.types import JSONRPCMessage
MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
@asynccontextmanager
async def create_client_server_memory_streams() -> (
AsyncGenerator[tuple[MessageStream, MessageStream], None]
):
"""
Creates a pair of bidirectional memory streams for client-server communication.
Returns:
A tuple of (client_streams, server_streams) where each is a tuple of
(read_stream, write_stream)
"""
# Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)
client_streams = (server_to_client_receive, client_to_server_send)
server_streams = (client_to_server_receive, server_to_client_send)
async with (
server_to_client_receive,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
):
yield client_streams, server_streams
@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server,
read_timeout_seconds: timedelta | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams
# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
lambda: server.run(
server_read,
server_write,
server.create_initialization_options(),
raise_exceptions=raise_exceptions,
)
)
try:
async with ClientSession(
read_stream=client_read,
write_stream=client_write,
read_timeout_seconds=read_timeout_seconds,
) as client_session:
await client_session.initialize()
yield client_session
finally:
tg.cancel_scope.cancel()

View File

@@ -0,0 +1,40 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from pydantic import BaseModel
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession
from mcp.types import ProgressToken
class Progress(BaseModel):
progress: float
total: float | None
@dataclass
class ProgressContext:
session: BaseSession
progress_token: ProgressToken
total: float | None
current: float = field(default=0.0, init=False)
async def progress(self, amount: float) -> None:
self.current += amount
await self.session.send_progress_notification(
self.progress_token, self.current, total=self.total
)
@contextmanager
def progress(ctx: RequestContext, total: float | None = None):
if ctx.meta is None or ctx.meta.progressToken is None:
raise ValueError("No progress token provided")
progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
try:
yield progress_ctx
finally:
pass

288
src/mcp/shared/session.py Normal file
View File

@@ -0,0 +1,288 @@
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar
import anyio
import anyio.lowlevel
import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
from mcp.shared.exceptions import McpError
from mcp.types import (
ClientNotification,
ClientRequest,
ClientResult,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestParams,
ServerNotification,
ServerRequest,
ServerResult,
)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar(
"ReceiveNotificationT", ClientNotification, ServerNotification
)
RequestId = str | int
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
def __init__(
self,
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession",
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self._responded = False
async def respond(self, response: SendResultT | ErrorData) -> None:
assert not self._responded, "Request already responded to"
self._responded = True
await self._session._send_response(
request_id=self.request_id, response=response
)
class BaseSession(
AbstractAsyncContextManager,
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
):
"""
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
This class is an async context manager that automatically starts processing
messages when entered.
"""
_response_streams: dict[
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_request_id: int
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]()
)
async def __aenter__(self):
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._receive_loop)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
async def send_request(
self,
request: SendRequestT,
result_type: type[ReceiveResultT],
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
response contains an error.
Do not use this method to emit notifications! Use send_notification()
instead.
"""
request_id = self._request_id
self._request_id = request_id + 1
response_stream, response_stream_reader = anyio.create_memory_object_stream[
JSONRPCResponse | JSONRPCError
](1)
self._response_streams[request_id] = response_stream
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)
# TODO: Support progress callbacks
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
try:
with anyio.fail_after(
None
if self._read_timeout_seconds is None
else self._read_timeout_seconds.total_seconds()
):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{self._read_timeout_seconds} seconds."
),
)
)
if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
return result_type.model_validate(response_or_error.result)
async def send_notification(self, notification: SendNotificationT) -> None:
"""
Emits a notification, which is a one-way message that does not expect
a response.
"""
jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
async def _send_response(
self, request_id: RequestId, response: SendResultT | ErrorData
) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
async def _receive_loop(self) -> None:
async with (
self._read_stream,
self._write_stream,
self._incoming_message_stream_writer,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._incoming_message_stream_writer.send(message)
elif isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
responder = RequestResponder(
request_id=message.root.id,
request_meta=validated_request.root.params._meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
)
await self._received_request(responder)
if not responder._responded:
await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
else: # Response or error
stream = self._response_streams.pop(message.root.id, None)
if stream:
await stream.send(message.root)
else:
await self._incoming_message_stream_writer.send(
RuntimeError(
"Received response with an unknown "
f"request ID: {message}"
)
)
async def _received_request(
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
) -> None:
"""
Can be overridden by subclasses to handle a request without needing to
listen on the message stream.
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
async def _received_notification(self, notification: ReceiveNotificationT) -> None:
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""
Sends a progress notification for a request that is currently being
processed.
"""
@property
def incoming_messages(
self,
) -> MemoryObjectReceiveStream[
RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]:
return self._incoming_message_stream_reader

View File

@@ -0,0 +1,3 @@
from mcp.types import LATEST_PROTOCOL_VERSION
SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION]

1041
src/mcp/types.py Normal file

File diff suppressed because it is too large Load Diff