From 4cbf8154306aa5b96b2bb3fc83ac5984d217a0f5 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 24 Sep 2024 22:04:19 +0100 Subject: [PATCH] Initial import --- .devcontainer/devcontainer.json | 22 + .gitignore | 162 ++++++++ README.md | 2 + mcp_python/__init__.py | 104 +++++ mcp_python/client/__init__.py | 0 mcp_python/client/__main__.py | 76 ++++ mcp_python/client/session.py | 211 ++++++++++ mcp_python/client/sse.py | 129 ++++++ mcp_python/client/stdio.py | 84 ++++ mcp_python/py.typed | 0 mcp_python/server/__init__.py | 347 ++++++++++++++++ mcp_python/server/__main__.py | 35 ++ mcp_python/server/session.py | 203 +++++++++ mcp_python/server/sse.py | 133 ++++++ mcp_python/server/stdio.py | 60 +++ mcp_python/server/types.py | 27 ++ mcp_python/server/websocket.py | 58 +++ mcp_python/shared/__init__.py | 0 mcp_python/shared/context.py | 14 + mcp_python/shared/exceptions.py | 9 + mcp_python/shared/progress.py | 40 ++ mcp_python/shared/session.py | 244 +++++++++++ mcp_python/shared/version.py | 1 + mcp_python/types.py | 709 ++++++++++++++++++++++++++++++++ pyproject.toml | 34 ++ tests/__init__.py | 0 tests/client/__init__.py | 0 tests/client/test_session.py | 93 +++++ tests/client/test_stdio.py | 38 ++ tests/server/__init__.py | 0 tests/server/test_session.py | 59 +++ tests/server/test_stdio.py | 68 +++ tests/test_types.py | 24 ++ 33 files changed, 2986 insertions(+) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .gitignore create mode 100644 README.md create mode 100644 mcp_python/__init__.py create mode 100644 mcp_python/client/__init__.py create mode 100644 mcp_python/client/__main__.py create mode 100644 mcp_python/client/session.py create mode 100644 mcp_python/client/sse.py create mode 100644 mcp_python/client/stdio.py create mode 100644 mcp_python/py.typed create mode 100644 mcp_python/server/__init__.py create mode 100644 mcp_python/server/__main__.py create mode 100644 mcp_python/server/session.py create mode 100644 mcp_python/server/sse.py create mode 100644 mcp_python/server/stdio.py create mode 100644 mcp_python/server/types.py create mode 100644 mcp_python/server/websocket.py create mode 100644 mcp_python/shared/__init__.py create mode 100644 mcp_python/shared/context.py create mode 100644 mcp_python/shared/exceptions.py create mode 100644 mcp_python/shared/progress.py create mode 100644 mcp_python/shared/session.py create mode 100644 mcp_python/shared/version.py create mode 100644 mcp_python/types.py create mode 100644 pyproject.toml create mode 100644 tests/__init__.py create mode 100644 tests/client/__init__.py create mode 100644 tests/client/test_session.py create mode 100644 tests/client/test_stdio.py create mode 100644 tests/server/__init__.py create mode 100644 tests/server/test_session.py create mode 100644 tests/server/test_stdio.py create mode 100644 tests/test_types.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..1fc5955 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,22 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/python:1-3.12-bookworm" + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "pip3 install --user -r requirements.txt", + + // Configure tool-specific properties. + // "customizations": {}, + + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..82f9275 --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..317d946 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# mcp-python +Model Context Protocol implementation for Python diff --git a/mcp_python/__init__.py b/mcp_python/__init__.py new file mode 100644 index 0000000..2285847 --- /dev/null +++ b/mcp_python/__init__.py @@ -0,0 +1,104 @@ +from .client.session import ClientSession +from .client.stdio import StdioServerParameters, stdio_client +from .server.session import ServerSession +from .server.stdio import stdio_server +from .shared.exceptions import McpError +from .types import ( + CallToolRequest, + ClientCapabilities, + ClientNotification, + ClientRequest, + ClientResult, + CompleteRequest, + CreateMessageRequest, + CreateMessageResult, + ErrorData, + GetPromptRequest, + GetPromptResult, + Implementation, + IncludeContext, + InitializedNotification, + InitializeRequest, + InitializeResult, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + ListPromptsRequest, + ListPromptsResult, + ListResourcesRequest, + ListResourcesResult, + ListToolsResult, + LoggingLevel, + LoggingMessageNotification, + Notification, + PingRequest, + ProgressNotification, + ReadResourceRequest, + ReadResourceResult, + Resource, + ResourceUpdatedNotification, + Role as SamplingRole, + SamplingMessage, + ServerCapabilities, + ServerNotification, + ServerRequest, + ServerResult, + SetLevelRequest, + StopReason, + SubscribeRequest, + Tool, + UnsubscribeRequest, +) + +__all__ = [ + "CallToolRequest", + "ClientCapabilities", + "ClientNotification", + "ClientRequest", + "ClientResult", + "ClientSession", + "CreateMessageRequest", + "CreateMessageResult", + "ErrorData", + "GetPromptRequest", + "GetPromptResult", + "Implementation", + "IncludeContext", + "InitializeRequest", + "InitializeResult", + "InitializedNotification", + "JSONRPCError", + "JSONRPCRequest", + "ListPromptsRequest", + "ListPromptsResult", + "ListResourcesRequest", + "ListResourcesResult", + "ListToolsResult", + "LoggingLevel", + "LoggingMessageNotification", + "McpError", + "Notification", + "PingRequest", + "ProgressNotification", + "ReadResourceRequest", + "ReadResourceResult", + "ResourceUpdatedNotification", + "Resource", + "SamplingMessage", + "SamplingRole", + "ServerCapabilities", + "ServerNotification", + "ServerRequest", + "ServerResult", + "ServerSession", + "SetLevelRequest", + "StdioServerParameters", + "StopReason", + "SubscribeRequest", + "Tool", + "UnsubscribeRequest", + "stdio_client", + "stdio_server", + "CompleteRequest", + "JSONRPCResponse", +] diff --git a/mcp_python/client/__init__.py b/mcp_python/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcp_python/client/__main__.py b/mcp_python/client/__main__.py new file mode 100644 index 0000000..e89f5fe --- /dev/null +++ b/mcp_python/client/__main__.py @@ -0,0 +1,76 @@ +import logging +import sys +from functools import partial +from urllib.parse import urlparse + +import anyio +import click + +from mcp_python.client.session import ClientSession +from mcp_python.client.sse import sse_client +from mcp_python.client.stdio import StdioServerParameters, stdio_client + +if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("client") + + +async def receive_loop(session: ClientSession): + logger.info("Starting receive loop") + async for message in session.incoming_messages: + if isinstance(message, Exception): + logger.error("Error: %s", message) + continue + + logger.info("Received message from server: %s", message) + + +async def run_session(read_stream, write_stream): + async with ( + ClientSession(read_stream, write_stream) as session, + anyio.create_task_group() as tg, + ): + tg.start_soon(receive_loop, session) + + logger.info("Initializing session") + await session.initialize() + logger.info("Initialized") + + +async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]): + env_dict = dict(env) + + if urlparse(command_or_url).scheme in ("http", "https"): + # Use SSE client for HTTP(S) URLs + async with sse_client(command_or_url) as streams: + await run_session(*streams) + else: + # Use stdio client for commands + server_parameters = StdioServerParameters( + command=command_or_url, args=args, env=env_dict + ) + async with stdio_client(server_parameters) as streams: + await run_session(*streams) + + +@click.command() +@click.argument("command_or_url") +@click.argument("args", nargs=-1) +@click.option( + "--env", + "-e", + multiple=True, + nargs=2, + metavar="KEY VALUE", + help="Environment variables to set. Can be used multiple times.", +) +def cli(*args, **kwargs): + anyio.run(partial(main, *args, **kwargs), backend="trio") + + +if __name__ == "__main__": + cli() diff --git a/mcp_python/client/session.py b/mcp_python/client/session.py new file mode 100644 index 0000000..5eab70e --- /dev/null +++ b/mcp_python/client/session.py @@ -0,0 +1,211 @@ +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import AnyUrl + +from mcp_python.shared.session import BaseSession +from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION +from mcp_python.types import ( + CallToolResult, + ClientCapabilities, + ClientNotification, + ClientRequest, + ClientResult, + EmptyResult, + Implementation, + InitializedNotification, + InitializeResult, + JSONRPCMessage, + ListResourcesResult, + LoggingLevel, + ReadResourceResult, + ServerNotification, + ServerRequest, +) + + +class ClientSession( + BaseSession[ + ClientRequest, + ClientNotification, + ClientResult, + ServerRequest, + ServerNotification, + ] +): + def __init__( + self, + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[JSONRPCMessage], + ) -> None: + super().__init__(read_stream, write_stream, ServerRequest, ServerNotification) + + async def initialize(self) -> InitializeResult: + from mcp_python.types import ( + InitializeRequest, + InitializeRequestParams, + ) + + result = await self.send_request( + ClientRequest( + InitializeRequest( + method="initialize", + params=InitializeRequestParams( + protocolVersion=SUPPORTED_PROTOCOL_VERSION, + capabilities=ClientCapabilities( + sampling=None, experimental=None + ), + clientInfo=Implementation(name="mcp_python", version="0.1.0"), + ), + ) + ), + InitializeResult, + ) + + if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION: + raise RuntimeError( + f"Unsupported protocol version from the server: {result.protocolVersion}" + ) + + await self.send_notification( + ClientNotification( + InitializedNotification(method="notifications/initialized") + ) + ) + + return result + + async def send_ping(self) -> EmptyResult: + """Send a ping request.""" + from mcp_python.types import PingRequest + + return await self.send_request( + ClientRequest( + PingRequest( + method="ping", + ) + ), + EmptyResult, + ) + + async def send_progress_notification( + self, progress_token: str | int, progress: float, total: float | None = None + ) -> None: + """Send a progress notification.""" + from mcp_python.types import ( + ProgressNotification, + ProgressNotificationParams, + ) + + await self.send_notification( + ClientNotification( + ProgressNotification( + method="notifications/progress", + params=ProgressNotificationParams( + progressToken=progress_token, + progress=progress, + total=total, + ), + ), + ) + ) + + async def set_logging_level(self, level: LoggingLevel) -> EmptyResult: + """Send a logging/setLevel request.""" + from mcp_python.types import ( + SetLevelRequest, + SetLevelRequestParams, + ) + + return await self.send_request( + ClientRequest( + SetLevelRequest( + method="logging/setLevel", + params=SetLevelRequestParams(level=level), + ) + ), + EmptyResult, + ) + + async def list_resources(self) -> ListResourcesResult: + """Send a resources/list request.""" + from mcp_python.types import ( + ListResourcesRequest, + ) + + return await self.send_request( + ClientRequest( + ListResourcesRequest( + method="resources/list", + ) + ), + ListResourcesResult, + ) + + async def read_resource(self, uri: AnyUrl) -> ReadResourceResult: + """Send a resources/read request.""" + from mcp_python.types import ( + ReadResourceRequest, + ReadResourceRequestParams, + ) + + return await self.send_request( + ClientRequest( + ReadResourceRequest( + method="resources/read", + params=ReadResourceRequestParams(uri=uri), + ) + ), + ReadResourceResult, + ) + + async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult: + """Send a resources/subscribe request.""" + from mcp_python.types import ( + SubscribeRequest, + SubscribeRequestParams, + ) + + return await self.send_request( + ClientRequest( + SubscribeRequest( + method="resources/subscribe", + params=SubscribeRequestParams(uri=uri), + ) + ), + EmptyResult, + ) + + async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult: + """Send a resources/unsubscribe request.""" + from mcp_python.types import ( + UnsubscribeRequest, + UnsubscribeRequestParams, + ) + + return await self.send_request( + ClientRequest( + UnsubscribeRequest( + method="resources/unsubscribe", + params=UnsubscribeRequestParams(uri=uri), + ) + ), + EmptyResult, + ) + + async def call_tool( + self, name: str, arguments: dict | None = None + ) -> CallToolResult: + """Send a tools/call request.""" + from mcp_python.types import ( + CallToolRequest, + CallToolRequestParams, + ) + + return await self.send_request( + ClientRequest( + CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name=name, arguments=arguments), + ) + ), + CallToolResult, + ) diff --git a/mcp_python/client/sse.py b/mcp_python/client/sse.py new file mode 100644 index 0000000..918fa96 --- /dev/null +++ b/mcp_python/client/sse.py @@ -0,0 +1,129 @@ +import logging +from contextlib import asynccontextmanager +from urllib.parse import urljoin, urlparse + +import anyio +import httpx +from anyio.abc import TaskStatus +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx_sse import aconnect_sse + +from mcp_python.types import JSONRPCMessage + +logger = logging.getLogger(__name__) + + +def remove_request_params(url: str) -> str: + return urljoin(url, urlparse(url).path) + + +@asynccontextmanager +async def sse_client(url: str, timeout: float = 5, sse_read_timeout: float = 60 * 5): + """ + Client transport for SSE. + + `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + """ + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async with anyio.create_task_group() as tg: + try: + logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx.AsyncClient() as client: + async with aconnect_sse( + client, + "GET", + url, + timeout=httpx.Timeout(timeout, read=sse_read_timeout), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + async def sse_reader( + task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, + ): + try: + async for sse in event_source.aiter_sse(): + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.info( + f"Received endpoint URL: {endpoint_url}" + ) + + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme + != endpoint_parsed.scheme + ): + error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}" + logger.error(error_msg) + raise ValueError(error_msg) + + task_status.started(endpoint_url) + + case "message": + try: + message = ( + JSONRPCMessage.model_validate_json( + sse.data + ) + ) + logger.debug( + f"Received server message: {message}" + ) + except Exception as exc: + logger.error( + f"Error parsing server message: {exc}" + ) + await read_stream_writer.send(exc) + continue + + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Error in sse_reader: {exc}") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + + async def post_writer(endpoint_url: str): + try: + async with write_stream_reader: + async for message in write_stream_reader: + logger.debug(f"Sending client message: {message}") + response = await client.post( + endpoint_url, + json=message.model_dump(by_alias=True, mode="json"), + ) + response.raise_for_status() + logger.debug( + f"Client message sent successfully: {response.status_code}" + ) + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") + finally: + await write_stream.aclose() + + endpoint_url = await tg.start(sse_reader) + logger.info( + f"Starting post writer with endpoint URL: {endpoint_url}" + ) + tg.start_soon(post_writer, endpoint_url) + + try: + yield read_stream, write_stream + finally: + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/mcp_python/client/stdio.py b/mcp_python/client/stdio.py new file mode 100644 index 0000000..2c71064 --- /dev/null +++ b/mcp_python/client/stdio.py @@ -0,0 +1,84 @@ +import sys +from contextlib import asynccontextmanager + +import anyio +import anyio.lowlevel +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.text import TextReceiveStream +from pydantic import BaseModel, Field + +from mcp_python.types import JSONRPCMessage + + +class StdioServerParameters(BaseModel): + command: str + """The executable to run to start the server.""" + + args: list[str] = Field(default_factory=list) + """Command line arguments to pass to the executable.""" + + env: dict[str, str] = Field(default_factory=dict) + """ + The environment to use when spawning the process. + + The environment is NOT inherited from the parent process by default. + """ + + +@asynccontextmanager +async def stdio_client(server: StdioServerParameters): + """ + Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. + """ + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + process = await anyio.open_process( + [server.command, *server.args], env=server.env, stderr=sys.stderr + ) + + async def stdout_reader(): + assert process.stdout, "Opened process is missing stdout" + + try: + async with read_stream_writer: + buffer = "" + async for chunk in TextReceiveStream(process.stdout): + lines = (buffer + chunk).split("\n") + buffer = lines.pop() + + for line in lines: + try: + message = JSONRPCMessage.model_validate_json(line) + except Exception as exc: + await read_stream_writer.send(exc) + continue + + await read_stream_writer.send(message) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async def stdin_writer(): + assert process.stdin, "Opened process is missing stdin" + + try: + async with write_stream_reader: + async for message in write_stream_reader: + json = message.model_dump_json(by_alias=True) + await process.stdin.send((json + "\n").encode()) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async with ( + anyio.create_task_group() as tg, + process, + ): + tg.start_soon(stdout_reader) + tg.start_soon(stdin_writer) + yield read_stream, write_stream diff --git a/mcp_python/py.typed b/mcp_python/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/mcp_python/server/__init__.py b/mcp_python/server/__init__.py new file mode 100644 index 0000000..375cbfe --- /dev/null +++ b/mcp_python/server/__init__.py @@ -0,0 +1,347 @@ +import contextvars +import logging +import warnings +from collections.abc import Awaitable, Callable +from typing import Any + +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import AnyUrl + +from mcp_python.server import types +from mcp_python.server.session import ServerSession +from mcp_python.server.stdio import stdio_server as stdio_server +from mcp_python.shared.context import RequestContext +from mcp_python.shared.session import RequestResponder +from mcp_python.types import ( + METHOD_NOT_FOUND, + CallToolRequest, + ClientNotification, + ClientRequest, + CompleteRequest, + ErrorData, + JSONRPCMessage, + ListResourcesRequest, + ListResourcesResult, + LoggingLevel, + ProgressNotification, + Prompt, + PromptReference, + ReadResourceRequest, + ReadResourceResult, + Resource, + ResourceReference, + ServerResult, + SetLevelRequest, + SubscribeRequest, + UnsubscribeRequest, +) + +logger = logging.getLogger(__name__) + + +request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar( + "request_ctx" +) + + +class Server: + def __init__(self, name: str): + self.name = name + self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {} + self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} + logger.info(f"Initializing server '{name}'") + + @property + def request_context(self) -> RequestContext: + """If called outside of a request context, this will raise a LookupError.""" + return request_ctx.get() + + def list_prompts(self): + from mcp_python.types import ListPromptsRequest, ListPromptsResult + + def decorator(func: Callable[[], Awaitable[list[Prompt]]]): + logger.debug(f"Registering handler for PromptListRequest") + + async def handler(_: Any): + prompts = await func() + return ServerResult(ListPromptsResult(prompts=prompts)) + + self.request_handlers[ListPromptsRequest] = handler + return func + + return decorator + + def get_prompt(self): + from mcp_python.types import ( + GetPromptRequest, + GetPromptResult, + ImageContent, + Role as Role, + SamplingMessage, + TextContent, + ) + + def decorator( + func: Callable[ + [str, dict[str, str] | None], Awaitable[types.PromptResponse] + ], + ): + logger.debug(f"Registering handler for GetPromptRequest") + + async def handler(req: GetPromptRequest): + prompt_get = await func(req.params.name, req.params.arguments) + messages = [] + for message in prompt_get.messages: + match message.content: + case str() as text_content: + content = TextContent(type="text", text=text_content) + case types.ImageContent() as img_content: + content = ImageContent( + type="image", + data=img_content.data, + mimeType=img_content.mime_type, + ) + case _: + raise ValueError( + f"Unexpected content type: {type(message.content)}" + ) + + sampling_message = SamplingMessage( + role=message.role, content=content + ) + messages.append(sampling_message) + + return ServerResult( + GetPromptResult(description=prompt_get.desc, messages=messages) + ) + + self.request_handlers[GetPromptRequest] = handler + return func + + return decorator + + def list_resources(self): + def decorator(func: Callable[[], Awaitable[list[Resource]]]): + logger.debug(f"Registering handler for ListResourcesRequest") + + async def handler(_: Any): + resources = await func() + return ServerResult( + ListResourcesResult(resources=resources, resourceTemplates=None) + ) + + self.request_handlers[ListResourcesRequest] = handler + return func + + return decorator + + def read_resource(self): + from mcp_python.types import ( + BlobResourceContents, + TextResourceContents, + ) + + def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]): + logger.debug(f"Registering handler for ReadResourceRequest") + + async def handler(req: ReadResourceRequest): + result = await func(req.params.uri) + match result: + case str(s): + content = TextResourceContents( + uri=req.params.uri, + text=s, + mimeType="text/plain", + ) + case bytes(b): + import base64 + + content = BlobResourceContents( + uri=req.params.uri, + blob=base64.urlsafe_b64encode(b).decode(), + mimeType="application/octet-stream", + ) + + return ServerResult( + ReadResourceResult( + contents=[content], + ) + ) + + self.request_handlers[ReadResourceRequest] = handler + return func + + return decorator + + def set_logging_level(self): + from mcp_python.types import EmptyResult + + def decorator(func: Callable[[LoggingLevel], Awaitable[None]]): + logger.debug(f"Registering handler for SetLevelRequest") + + async def handler(req: SetLevelRequest): + await func(req.params.level) + return ServerResult(EmptyResult()) + + self.request_handlers[SetLevelRequest] = handler + return func + + return decorator + + def subscribe_resource(self): + from mcp_python.types import EmptyResult + + def decorator(func: Callable[[AnyUrl], Awaitable[None]]): + logger.debug(f"Registering handler for SubscribeRequest") + + async def handler(req: SubscribeRequest): + await func(req.params.uri) + return ServerResult(EmptyResult()) + + self.request_handlers[SubscribeRequest] = handler + return func + + return decorator + + def unsubscribe_resource(self): + from mcp_python.types import EmptyResult + + def decorator(func: Callable[[AnyUrl], Awaitable[None]]): + logger.debug(f"Registering handler for UnsubscribeRequest") + + async def handler(req: UnsubscribeRequest): + await func(req.params.uri) + return ServerResult(EmptyResult()) + + self.request_handlers[UnsubscribeRequest] = handler + return func + + return decorator + + def call_tool(self): + from mcp_python.types import CallToolResult + + def decorator(func: Callable[..., Awaitable[Any]]): + logger.debug(f"Registering handler for CallToolRequest") + + async def handler(req: CallToolRequest): + result = await func(req.params.name, **(req.params.arguments or {})) + return ServerResult(CallToolResult(toolResult=result)) + + self.request_handlers[CallToolRequest] = handler + return func + + return decorator + + def progress_notification(self): + def decorator( + func: Callable[[str | int, float, float | None], Awaitable[None]], + ): + logger.debug(f"Registering handler for ProgressNotification") + + async def handler(req: ProgressNotification): + await func( + req.params.progressToken, req.params.progress, req.params.total + ) + + self.notification_handlers[ProgressNotification] = handler + return func + + return decorator + + def completion(self): + """Provides completions for prompts and resource templates""" + from mcp_python.types import CompleteResult, Completion, CompletionArgument + + def decorator( + func: Callable[ + [PromptReference | ResourceReference, CompletionArgument], + Awaitable[Completion | None], + ], + ): + logger.debug(f"Registering handler for CompleteRequest") + + async def handler(req: CompleteRequest): + completion = await func(req.params.ref, req.params.argument) + return ServerResult( + CompleteResult( + completion=completion + if completion is not None + else Completion(values=[], total=None, hasMore=None), + ) + ) + + self.request_handlers[CompleteRequest] = handler + return func + + return decorator + + async def run( + self, + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[JSONRPCMessage], + ): + with warnings.catch_warnings(record=True) as w: + async with ServerSession(read_stream, write_stream) as session: + async for message in session.incoming_messages: + logger.debug(f"Received message: {message}") + + match message: + case RequestResponder(request=ClientRequest(root=req)): + logger.info( + f"Processing request of type {type(req).__name__}" + ) + if type(req) in self.request_handlers: + handler = self.request_handlers[type(req)] + logger.debug( + f"Dispatching request of type {type(req).__name__}" + ) + + try: + # Set our global state that can be retrieved via + # app.get_request_context() + token = request_ctx.set( + RequestContext( + message.request_id, + message.request_meta, + session, + ) + ) + response = await handler(req) + # Reset the global state after we are done + request_ctx.reset(token) + except Exception as err: + response = ErrorData( + code=0, message=str(err), data=None + ) + + await message.respond(response) + else: + await message.respond( + ErrorData( + code=METHOD_NOT_FOUND, + message="Method not found", + ) + ) + + logger.debug("Response sent") + case ClientNotification(root=notify): + if type(notify) in self.notification_handlers: + assert type(notify) in self.notification_handlers + + handler = self.notification_handlers[type(notify)] + logger.debug( + f"Dispatching notification of type {type(notify).__name__}" + ) + + try: + await handler(notify) + except Exception as err: + logger.error( + f"Uncaught exception in notification handler: {err}" + ) + + for warning in w: + logger.info( + f"Warning: {warning.category.__name__}: {warning.message}" + ) diff --git a/mcp_python/server/__main__.py b/mcp_python/server/__main__.py new file mode 100644 index 0000000..907b453 --- /dev/null +++ b/mcp_python/server/__main__.py @@ -0,0 +1,35 @@ +import logging +import sys + +import anyio + +from mcp_python.server.session import ServerSession +from mcp_python.server.stdio import stdio_server + +if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("server") + + +async def receive_loop(session: ServerSession): + logger.info("Starting receive loop") + async for message in session.incoming_messages: + if isinstance(message, Exception): + logger.error("Error: %s", message) + continue + + logger.info("Received message from client: %s", message) + + +async def main(): + async with stdio_server() as (read_stream, write_stream): + async with ServerSession(read_stream, write_stream) as session, write_stream: + await receive_loop(session) + + +if __name__ == "__main__": + anyio.run(main, backend="trio") diff --git a/mcp_python/server/session.py b/mcp_python/server/session.py new file mode 100644 index 0000000..687fda6 --- /dev/null +++ b/mcp_python/server/session.py @@ -0,0 +1,203 @@ +from enum import Enum +from typing import Any + +import anyio +import anyio.lowlevel +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import AnyUrl + +from mcp_python.shared.session import ( + BaseSession, + RequestResponder, +) +from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION +from mcp_python.types import ( + ClientNotification, + ClientRequest, + CreateMessageResult, + EmptyResult, + Implementation, + IncludeContext, + InitializedNotification, + InitializeRequest, + InitializeResult, + JSONRPCMessage, + LoggingLevel, + SamplingMessage, + ServerCapabilities, + ServerNotification, + ServerRequest, + ServerResult, +) + + +class InitializationState(Enum): + NotInitialized = 1 + Initializing = 2 + Initialized = 3 + + +class ServerSession( + BaseSession[ + ServerRequest, + ServerNotification, + ServerResult, + ClientRequest, + ClientNotification, + ] +): + _initialized: InitializationState = InitializationState.NotInitialized + + def __init__( + self, + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[JSONRPCMessage], + ) -> None: + super().__init__(read_stream, write_stream, ClientRequest, ClientNotification) + self._initialization_state = InitializationState.NotInitialized + + async def _received_request( + self, responder: RequestResponder[ClientRequest, ServerResult] + ): + match responder.request.root: + case InitializeRequest(): + self._initialization_state = InitializationState.Initializing + await responder.respond( + ServerResult( + InitializeResult( + protocolVersion=SUPPORTED_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + logging=None, + resources=None, + tools=None, + experimental=None, + prompts={}, + ), + serverInfo=Implementation( + name="mcp_python", version="0.1.0" + ), + ) + ) + ) + case _: + if self._initialization_state != InitializationState.Initialized: + raise RuntimeError( + "Received request before initialization was complete" + ) + + async def _received_notification(self, notification: ClientNotification) -> None: + # Need this to avoid ASYNC910 + await anyio.lowlevel.checkpoint() + match notification.root: + case InitializedNotification(): + self._initialization_state = InitializationState.Initialized + case _: + if self._initialization_state != InitializationState.Initialized: + raise RuntimeError( + "Received notification before initialization was complete" + ) + + async def send_log_message( + self, level: LoggingLevel, data: Any, logger: str | None = None + ) -> None: + """Send a log message notification.""" + from mcp_python.types import ( + LoggingMessageNotification, + LoggingMessageNotificationParams, + ) + + await self.send_notification( + ServerNotification( + LoggingMessageNotification( + method="notifications/message", + params=LoggingMessageNotificationParams( + level=level, + data=data, + logger=logger, + ), + ) + ) + ) + + async def send_resource_updated(self, uri: AnyUrl) -> None: + """Send a resource updated notification.""" + from mcp_python.types import ( + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + ) + + await self.send_notification( + ServerNotification( + ResourceUpdatedNotification( + method="notifications/resources/updated", + params=ResourceUpdatedNotificationParams(uri=uri), + ) + ) + ) + + async def request_create_message( + self, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + ) -> CreateMessageResult: + """Send a sampling/create_message request.""" + from mcp_python.types import ( + CreateMessageRequest, + CreateMessageRequestParams, + ) + + return await self.send_request( + ServerRequest( + CreateMessageRequest( + method="sampling/createMessage", + params=CreateMessageRequestParams( + messages=messages, + systemPrompt=system_prompt, + includeContext=include_context, + temperature=temperature, + maxTokens=max_tokens, + stopSequences=stop_sequences, + metadata=metadata, + ), + ) + ), + CreateMessageResult, + ) + + async def send_ping(self) -> EmptyResult: + """Send a ping request.""" + from mcp_python.types import PingRequest + + return await self.send_request( + ServerRequest( + PingRequest( + method="ping", + ) + ), + EmptyResult, + ) + + async def send_progress_notification( + self, progress_token: str | int, progress: float, total: float | None = None + ) -> None: + """Send a progress notification.""" + from mcp_python.types import ProgressNotification, ProgressNotificationParams + + await self.send_notification( + ServerNotification( + ProgressNotification( + method="notifications/progress", + params=ProgressNotificationParams( + progressToken=progress_token, + progress=progress, + total=total, + ), + ) + ) + ) diff --git a/mcp_python/server/sse.py b/mcp_python/server/sse.py new file mode 100644 index 0000000..8ae6436 --- /dev/null +++ b/mcp_python/server/sse.py @@ -0,0 +1,133 @@ +import logging +from contextlib import asynccontextmanager +from typing import Any +from urllib.parse import quote +from uuid import UUID, uuid4 + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp_python.types import JSONRPCMessage + +logger = logging.getLogger(__name__) + + +class SseServerTransport: + """ + SSE server transport for MCP. This class provides _two_ ASGI applications, suitable to be used with a framework like Starlette and a server like Hypercorn: + + 1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client. + 2. handle_post_message() is an ASGI application which receives incoming POST requests, which should contain client messages that link to a previously-established SSE session. + """ + + _endpoint: str + _read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]] + + def __init__(self, endpoint: str) -> None: + """ + Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given. + """ + + super().__init__() + self._endpoint = endpoint + self._read_stream_writers = {} + logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") + + @asynccontextmanager + async def connect_sse(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] != "http": + logger.error("connect_sse received non-HTTP request") + raise ValueError("connect_sse can only handle HTTP requests") + + logger.debug("Setting up SSE connection") + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + session_id = uuid4() + session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" + self._read_stream_writers[session_id] = read_stream_writer + logger.debug(f"Created new session with ID: {session_id}") + + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream( + 0, dict[str, Any] + ) + + async def sse_writer(): + logger.debug("Starting SSE writer") + async with sse_stream_writer, write_stream_reader: + await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) + logger.debug(f"Sent endpoint event: {session_uri}") + + async for message in write_stream_reader: + logger.debug(f"Sending message via SSE: {message}") + await sse_stream_writer.send( + { + "event": "message", + "data": message.model_dump_json(by_alias=True), + } + ) + + async with anyio.create_task_group() as tg: + response = EventSourceResponse( + content=sse_stream_reader, data_sender_callable=sse_writer + ) + logger.debug("Starting SSE response task") + tg.start_soon(response, scope, receive, send) + + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + + async def handle_post_message( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + logger.debug("Handling POST message") + request = Request(scope, receive) + + session_id_param = request.query_params.get("session_id") + if session_id_param is None: + logger.warning("Received request without session_id") + response = Response("session_id is required", status_code=400) + return await response(scope, receive, send) + + try: + session_id = UUID(hex=session_id_param) + logger.debug(f"Parsed session ID: {session_id}") + except ValueError: + logger.warning(f"Received invalid session ID: {session_id_param}") + response = Response("Invalid session ID", status_code=400) + return await response(scope, receive, send) + + writer = self._read_stream_writers.get(session_id) + if not writer: + logger.warning(f"Could not find session for ID: {session_id}") + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + + json = await request.json() + logger.debug(f"Received JSON: {json}") + + try: + message = JSONRPCMessage.model_validate(json) + logger.debug(f"Validated client message: {message}") + except ValidationError as err: + logger.error(f"Failed to parse message: {err}") + response = Response("Could not parse message", status_code=400) + await response(scope, receive, send) + await writer.send(err) + return + + logger.debug(f"Sending message to writer: {message}") + response = Response("Accepted", status_code=202) + await response(scope, receive, send) + await writer.send(message) diff --git a/mcp_python/server/stdio.py b/mcp_python/server/stdio.py new file mode 100644 index 0000000..8757c3d --- /dev/null +++ b/mcp_python/server/stdio.py @@ -0,0 +1,60 @@ +import sys +from contextlib import asynccontextmanager + +import anyio +import anyio.lowlevel +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp_python.types import JSONRPCMessage + + +@asynccontextmanager +async def stdio_server( + stdin: anyio.AsyncFile | None = None, stdout: anyio.AsyncFile | None = None +): + """ + Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout. + """ + # Purposely not using context managers for these, as we don't want to close standard process handles. + if not stdin: + stdin = anyio.wrap_file(sys.stdin) + if not stdout: + stdout = anyio.wrap_file(sys.stdout) + + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async def stdin_reader(): + try: + async with read_stream_writer: + async for line in stdin: + try: + message = JSONRPCMessage.model_validate_json(line) + except Exception as exc: + await read_stream_writer.send(exc) + continue + + await read_stream_writer.send(message) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async def stdout_writer(): + try: + async with write_stream_reader: + async for message in write_stream_reader: + json = message.model_dump_json(by_alias=True) + await stdout.write(json + "\n") + await stdout.flush() + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async with anyio.create_task_group() as tg: + tg.start_soon(stdin_reader) + tg.start_soon(stdout_writer) + yield read_stream, write_stream diff --git a/mcp_python/server/types.py b/mcp_python/server/types.py new file mode 100644 index 0000000..2993d84 --- /dev/null +++ b/mcp_python/server/types.py @@ -0,0 +1,27 @@ +""" +This module provides simpler types to use with the server for managing prompts. +""" + +from dataclasses import dataclass +from typing import Literal + +from mcp_python.types import Role + + +@dataclass +class ImageContent: + type: Literal["image"] + data: str + mime_type: str + + +@dataclass +class Message: + role: Role + content: str | ImageContent + + +@dataclass +class PromptResponse: + messages: list[Message] + desc: str | None = None diff --git a/mcp_python/server/websocket.py b/mcp_python/server/websocket.py new file mode 100644 index 0000000..6547d77 --- /dev/null +++ b/mcp_python/server/websocket.py @@ -0,0 +1,58 @@ +import logging +from contextlib import asynccontextmanager + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from starlette.types import Receive, Scope, Send +from starlette.websockets import WebSocket + +from mcp_python.types import JSONRPCMessage + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def websocket_server(scope: Scope, receive: Receive, send: Send): + """ + WebSocket server transport for MCP. This is an ASGI application, suitable to be used with a framework like Starlette and a server like Hypercorn. + """ + + websocket = WebSocket(scope, receive, send) + await websocket.accept(subprotocol="mcp") + + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async def ws_reader(): + try: + async with read_stream_writer: + async for message in websocket.iter_json(): + try: + client_message = JSONRPCMessage.model_validate(message) + except Exception as exc: + await read_stream_writer.send(exc) + continue + + await read_stream_writer.send(client_message) + except anyio.ClosedResourceError: + await websocket.close() + + async def ws_writer(): + try: + async with write_stream_reader: + async for message in write_stream_reader: + obj = message.model_dump(by_alias=True, mode="json") + await websocket.send_json(obj) + except anyio.ClosedResourceError: + await websocket.close() + + async with anyio.create_task_group() as tg: + tg.start_soon(ws_reader) + tg.start_soon(ws_writer) + yield (read_stream, write_stream) diff --git a/mcp_python/shared/__init__.py b/mcp_python/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcp_python/shared/context.py b/mcp_python/shared/context.py new file mode 100644 index 0000000..20481d6 --- /dev/null +++ b/mcp_python/shared/context.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import Generic, TypeVar + +from mcp_python.shared.session import BaseSession +from mcp_python.types import RequestId, RequestParams + +SessionT = TypeVar("SessionT", bound=BaseSession) + + +@dataclass +class RequestContext(Generic[SessionT]): + request_id: RequestId + meta: RequestParams.Meta | None + session: SessionT diff --git a/mcp_python/shared/exceptions.py b/mcp_python/shared/exceptions.py new file mode 100644 index 0000000..cf98cdf --- /dev/null +++ b/mcp_python/shared/exceptions.py @@ -0,0 +1,9 @@ +from mcp_python.types import ErrorData + + +class McpError(Exception): + """ + Exception type raised when an error arrives over an MCP connection. + """ + + error: ErrorData diff --git a/mcp_python/shared/progress.py b/mcp_python/shared/progress.py new file mode 100644 index 0000000..d94740d --- /dev/null +++ b/mcp_python/shared/progress.py @@ -0,0 +1,40 @@ +from contextlib import contextmanager +from dataclasses import dataclass, field + +from pydantic import BaseModel + +from mcp_python.shared.context import RequestContext +from mcp_python.shared.session import BaseSession +from mcp_python.types import ProgressToken + + +class Progress(BaseModel): + progress: float + total: float | None + + +@dataclass +class ProgressContext: + session: BaseSession + progress_token: ProgressToken + total: float | None + current: float = field(default=0.0, init=False) + + async def progress(self, amount: float) -> None: + self.current += amount + + await self.session.send_progress_notification( + self.progress_token, self.current, total=self.total + ) + + +@contextmanager +def progress(ctx: RequestContext, total: float | None = None): + if ctx.meta is None or ctx.meta.progressToken is None: + raise ValueError("No progress token provided") + + progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total) + try: + yield progress_ctx + finally: + pass diff --git a/mcp_python/shared/session.py b/mcp_python/shared/session.py new file mode 100644 index 0000000..8be9386 --- /dev/null +++ b/mcp_python/shared/session.py @@ -0,0 +1,244 @@ +from contextlib import AbstractAsyncContextManager +from typing import Generic, TypeVar + +import anyio +import anyio.lowlevel +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import BaseModel + +from mcp_python.shared.exceptions import McpError +from mcp_python.types import ( + ClientNotification, + ClientRequest, + ClientResult, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + RequestParams, + ServerNotification, + ServerRequest, + ServerResult, +) + +SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) +SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) +SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) +ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) +ReceiveNotificationT = TypeVar( + "ReceiveNotificationT", ClientNotification, ServerNotification +) + +RequestId = str | int + + +class RequestResponder(Generic[ReceiveRequestT, SendResultT]): + def __init__( + self, + request_id: RequestId, + request_meta: RequestParams.Meta | None, + request: ReceiveRequestT, + session: "BaseSession", + ) -> None: + self.request_id = request_id + self.request_meta = request_meta + self.request = request + self._session = session + self._responded = False + + async def respond(self, response: SendResultT | ErrorData) -> None: + assert not self._responded, "Request already responded to" + self._responded = True + + await self._session._send_response( + request_id=self.request_id, response=response + ) + + +class BaseSession( + AbstractAsyncContextManager, + Generic[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT, + ], +): + """ + Implements an MCP "session" on top of read/write streams, including features like request/response linking, notifications, and progress. + + This class is an async context manager that automatically starts processing messages when entered. + """ + + _response_streams: dict[ + RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] + ] + _request_id: int + + def __init__( + self, + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[JSONRPCMessage], + receive_request_type: type[ReceiveRequestT], + receive_notification_type: type[ReceiveNotificationT], + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + self._response_streams = {} + self._request_id = 0 + self._receive_request_type = receive_request_type + self._receive_notification_type = receive_notification_type + + self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( + anyio.create_memory_object_stream[ + RequestResponder[ReceiveRequestT, SendResultT] + | ReceiveNotificationT + | Exception + ]() + ) + + async def __aenter__(self): + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + self._task_group.start_soon(self._receive_loop) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Using BaseSession as a context manager should not block on exit (this would be very surprising behavior), so make sure to cancel the tasks in the task group. + self._task_group.cancel_scope.cancel() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def send_request( + self, + request: SendRequestT, + result_type: type[ReceiveResultT], + ) -> ReceiveResultT: + """ + Sends a request and wait for a response. Raises an McpError if the response contains an error. + + Do not use this method to emit notifications! Use send_notification() instead. + """ + + request_id = self._request_id + self._request_id = request_id + 1 + + response_stream, response_stream_reader = anyio.create_memory_object_stream[ + JSONRPCResponse | JSONRPCError + ](1) + self._response_streams[request_id] = response_stream + + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", id=request_id, **request.model_dump(by_alias=True, mode="json") + ) + + # TODO: Support progress callbacks + + await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + + response_or_error = await response_stream_reader.receive() + if isinstance(response_or_error, JSONRPCError): + raise McpError(response_or_error.error) + else: + return result_type.model_validate(response_or_error.result) + + async def send_notification(self, notification: SendNotificationT) -> None: + """ + Emits a notification, which is a one-way message that does not expect a response. + """ + jsonrpc_notification = JSONRPCNotification( + jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json") + ) + + await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) + + async def _send_response( + self, request_id: RequestId, response: SendResultT | ErrorData + ) -> None: + if isinstance(response, ErrorData): + jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) + await self._write_stream.send(JSONRPCMessage(jsonrpc_error)) + else: + jsonrpc_response = JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + result=response.model_dump(by_alias=True, mode="json"), + ) + await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) + + async def _receive_loop(self) -> None: + async with ( + self._read_stream, + self._write_stream, + self._incoming_message_stream_writer, + ): + async for message in self._read_stream: + if isinstance(message, Exception): + await self._incoming_message_stream_writer.send(message) + elif isinstance(message.root, JSONRPCRequest): + validated_request = self._receive_request_type.model_validate( + message.root.model_dump(by_alias=True, mode="json") + ) + responder = RequestResponder( + request_id=message.root.id, + request_meta=validated_request.root.params._meta + if validated_request.root.params + else None, + request=validated_request, + session=self, + ) + + await self._received_request(responder) + if not responder._responded: + await self._incoming_message_stream_writer.send(responder) + elif isinstance(message.root, JSONRPCNotification): + notification = self._receive_notification_type.model_validate( + message.root.model_dump(by_alias=True, mode="json") + ) + + await self._received_notification(notification) + await self._incoming_message_stream_writer.send(notification) + else: # Response or error + stream = self._response_streams.pop(message.root.id, None) + if stream: + await stream.send(message.root) + else: + await self._incoming_message_stream_writer.send( + RuntimeError( + f"Received response with an unknown request ID: {message}" + ) + ) + + async def _received_request( + self, responder: RequestResponder[ReceiveRequestT, SendResultT] + ) -> None: + """ + Can be overridden by subclasses to handle a request without needing to listen on the message stream. + + If the request is responded to within this method, it will not be forwarded on to the message stream. + """ + + async def _received_notification(self, notification: ReceiveNotificationT) -> None: + """ + Can be overridden by subclasses to handle a notification without needing to listen on the message stream. + """ + + async def send_progress_notification( + self, progress_token: str | int, progress: float, total: float | None = None + ) -> None: + """ + Sends a progress notification for a request that is currently being processed. + """ + + @property + def incoming_messages( + self, + ) -> MemoryObjectReceiveStream[ + RequestResponder[ReceiveRequestT, SendResultT] + | ReceiveNotificationT + | Exception + ]: + return self._incoming_message_stream_reader diff --git a/mcp_python/shared/version.py b/mcp_python/shared/version.py new file mode 100644 index 0000000..bc8db20 --- /dev/null +++ b/mcp_python/shared/version.py @@ -0,0 +1 @@ +SUPPORTED_PROTOCOL_VERSION = 1 diff --git a/mcp_python/types.py b/mcp_python/types.py new file mode 100644 index 0000000..becc2b1 --- /dev/null +++ b/mcp_python/types.py @@ -0,0 +1,709 @@ +from typing import Any, Generic, Literal, TypeVar + +from pydantic import BaseModel, ConfigDict, RootModel +from pydantic.networks import AnyUrl + +""" +Model Context Protocol bindings for Python + +These bindings were generated from https://github.com/anthropic-experimental/mcp-spec, using Claude, with a prompt something like the following: + +Generate idiomatic Python bindings for this schema for MCP, or the "Model Context Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version for reference. + +* For the bindings, let's use Pydantic V2 models. +* Each model should allow extra fields everywhere, by specifying `model_config = ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class. +* Union types should be represented with a Pydantic `RootModel`. +* Define additional model classes instead of using dictionaries. Do this even if they're not separate types in the schema. +""" + + +ProgressToken = str | int + + +class RequestParams(BaseModel): + class Meta(BaseModel): + progressToken: ProgressToken | None = None + """ + If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. + """ + + model_config = ConfigDict(extra="allow") + + _meta: Meta | None = None + + +class NotificationParams(BaseModel): + class Meta(BaseModel): + model_config = ConfigDict(extra="allow") + + _meta: Meta | None = None + """ + This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. + """ + + +RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams) +NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams) + + +class Request(BaseModel, Generic[RequestParamsT]): + """Base class for JSON-RPC requests.""" + + method: str + params: RequestParamsT + model_config = ConfigDict(extra="allow") + + +class Notification(BaseModel, Generic[NotificationParamsT]): + """Base class for JSON-RPC notifications.""" + + method: str + model_config = ConfigDict(extra="allow") + + +class Result(BaseModel): + """Base class for JSON-RPC results.""" + + model_config = ConfigDict(extra="allow") + + _meta: dict[str, Any] | None = None + """ + This result property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses. + """ + + +RequestId = str | int + + +class JSONRPCRequest(Request): + """A request that expects a response.""" + + jsonrpc: Literal["2.0"] + id: RequestId + params: dict[str, Any] | None = None + + +class JSONRPCNotification(Notification): + """A notification which does not expect a response.""" + + jsonrpc: Literal["2.0"] + params: dict[str, Any] | None = None + + +class JSONRPCResponse(BaseModel): + """A successful (non-error) response to a request.""" + + jsonrpc: Literal["2.0"] + id: RequestId + result: dict[str, Any] + model_config = ConfigDict(extra="allow") + + +# Standard JSON-RPC error codes +PARSE_ERROR = -32700 +INVALID_REQUEST = -32600 +METHOD_NOT_FOUND = -32601 +INVALID_PARAMS = -32602 + + +class ErrorData(BaseModel): + """Error information for JSON-RPC error responses.""" + + code: int + """The error type that occurred.""" + message: str + """A short description of the error. The message SHOULD be limited to a concise single sentence.""" + data: Any | None = None + """Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.).""" + model_config = ConfigDict(extra="allow") + + +class JSONRPCError(BaseModel): + """A response to a request that indicates an error occurred.""" + + jsonrpc: Literal["2.0"] + id: str | int + error: ErrorData + model_config = ConfigDict(extra="allow") + + +class JSONRPCMessage( + RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError] +): + pass + + +class EmptyResult(Result): + """A response that indicates success but carries no data.""" + + +class Implementation(BaseModel): + """Describes the name and version of an MCP implementation.""" + + name: str + version: str + model_config = ConfigDict(extra="allow") + + +class ClientCapabilities(BaseModel): + """Capabilities a client may support.""" + + experimental: dict[str, dict[str, Any]] | None = None + """Experimental, non-standard capabilities that the client supports.""" + sampling: dict[str, Any] | None = None + """Present if the client supports sampling from an LLM.""" + model_config = ConfigDict(extra="allow") + + +class ServerCapabilities(BaseModel): + """Capabilities that a server may support.""" + + experimental: dict[str, dict[str, Any]] | None = None + """Experimental, non-standard capabilities that the server supports.""" + logging: dict[str, Any] | None = None + """Present if the server supports sending log messages to the client.""" + prompts: dict[str, Any] | None = None + """Present if the server offers any prompt templates.""" + resources: dict[str, Any] | None = None + """Present if the server offers any resources to read.""" + tools: dict[str, Any] | None = None + """Present if the server offers any tools to call.""" + model_config = ConfigDict(extra="allow") + + +class InitializeRequestParams(RequestParams): + """Parameters for the initialize request.""" + + protocolVersion: Literal[1] + """The latest version of the Model Context Protocol that the client supports.""" + capabilities: ClientCapabilities + clientInfo: Implementation + model_config = ConfigDict(extra="allow") + + +class InitializeRequest(Request): + """This request is sent from the client to the server when it first connects, asking it to begin initialization.""" + + method: Literal["initialize"] + params: InitializeRequestParams + + +class InitializeResult(Result): + """After receiving an initialize request from the client, the server sends this response.""" + + protocolVersion: Literal[1] + """The version of the Model Context Protocol that the server wants to use.""" + capabilities: ServerCapabilities + serverInfo: Implementation + + +class InitializedNotification(Notification): + """This notification is sent from the client to the server after initialization has finished.""" + + method: Literal["notifications/initialized"] + params: NotificationParams | None = None + + +class PingRequest(Request): + """A ping, issued by either the server or the client, to check that the other party is still alive.""" + + method: Literal["ping"] + params: RequestParams | None = None + + +class ProgressNotificationParams(NotificationParams): + """Parameters for progress notifications.""" + + progressToken: ProgressToken + """The progress token which was given in the initial request, used to associate this notification with the request that is proceeding.""" + progress: float + """The progress thus far. This should increase every time progress is made, even if the total is unknown.""" + total: float | None = None + """Total number of items to process (or total progress required), if known.""" + model_config = ConfigDict(extra="allow") + + +class ProgressNotification(Notification): + """An out-of-band notification used to inform the receiver of a progress update for a long-running request.""" + + method: Literal["notifications/progress"] + params: ProgressNotificationParams + + +class ListResourcesRequest(Request): + """Sent from the client to request a list of resources the server has.""" + + method: Literal["resources/list"] + params: RequestParams | None = None + + +class Resource(BaseModel): + """A known resource that the server is capable of reading.""" + + uri: AnyUrl + """The URI of this resource.""" + mimeType: str | None = None + """The MIME type of this resource, if known.""" + model_config = ConfigDict(extra="allow") + + +class ResourceTemplate(BaseModel): + """A template description for resources available on the server.""" + + uriTemplate: str + """A URI template (according to RFC 6570) that can be used to construct resource URIs.""" + name: str | None = None + """A human-readable name for the type of resource this template refers to.""" + description: str | None = None + """A human-readable description of what this template is for.""" + mimeType: str | None = None + """The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type.""" + model_config = ConfigDict(extra="allow") + + +class ListResourcesResult(Result): + """The server's response to a resources/list request from the client.""" + + resourceTemplates: list[ResourceTemplate] | None = None + resources: list[Resource] | None = None + + +class ReadResourceRequestParams(RequestParams): + """Parameters for reading a resource.""" + + uri: AnyUrl + """The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it.""" + model_config = ConfigDict(extra="allow") + + +class ReadResourceRequest(Request): + """Sent from the client to the server, to read a specific resource URI.""" + + method: Literal["resources/read"] + params: ReadResourceRequestParams + + +class ResourceContents(BaseModel): + """The contents of a specific resource or sub-resource.""" + + uri: AnyUrl + """The URI of this resource.""" + mimeType: str | None = None + """The MIME type of this resource, if known.""" + model_config = ConfigDict(extra="allow") + + +class TextResourceContents(ResourceContents): + """Text contents of a resource.""" + + text: str + """The text of the item. This must only be set if the item can actually be represented as text (not binary data).""" + + +class BlobResourceContents(ResourceContents): + """Binary contents of a resource.""" + + blob: str + """A base64-encoded string representing the binary data of the item.""" + + +class ReadResourceResult(Result): + """The server's response to a resources/read request from the client.""" + + contents: list[TextResourceContents | BlobResourceContents] + + +class ResourceListChangedNotification(Notification): + """An optional notification from the server to the client, informing it that the list of resources it can read from has changed.""" + + method: Literal["notifications/resources/list_changed"] + params: NotificationParams | None = None + + +class SubscribeRequestParams(RequestParams): + """Parameters for subscribing to a resource.""" + + uri: AnyUrl + """The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it.""" + model_config = ConfigDict(extra="allow") + + +class SubscribeRequest(Request): + """Sent from the client to request resources/updated notifications from the server whenever a particular resource changes.""" + + method: Literal["resources/subscribe"] + params: SubscribeRequestParams + + +class UnsubscribeRequestParams(RequestParams): + """Parameters for unsubscribing from a resource.""" + + uri: AnyUrl + """The URI of the resource to unsubscribe from.""" + model_config = ConfigDict(extra="allow") + + +class UnsubscribeRequest(Request): + """Sent from the client to request cancellation of resources/updated notifications from the server.""" + + method: Literal["resources/unsubscribe"] + params: UnsubscribeRequestParams + + +class ResourceUpdatedNotificationParams(NotificationParams): + """Parameters for resource update notifications.""" + + uri: AnyUrl + """The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to.""" + model_config = ConfigDict(extra="allow") + + +class ResourceUpdatedNotification(Notification): + """A notification from the server to the client, informing it that a resource has changed and may need to be read again.""" + + method: Literal["notifications/resources/updated"] + params: ResourceUpdatedNotificationParams + + +class ListPromptsRequest(Request): + """Sent from the client to request a list of prompts and prompt templates the server has.""" + + method: Literal["prompts/list"] + params: RequestParams | None = None + + +class PromptArgument(BaseModel): + """An argument for a prompt template.""" + + name: str + """The name of the argument.""" + description: str | None = None + """A human-readable description of the argument.""" + required: bool | None = None + """Whether this argument must be provided.""" + model_config = ConfigDict(extra="allow") + + +class Prompt(BaseModel): + """A prompt or prompt template that the server offers.""" + + name: str + """The name of the prompt or prompt template.""" + description: str | None = None + """An optional description of what this prompt provides.""" + arguments: list[PromptArgument] | None = None + """A list of arguments to use for templating the prompt.""" + model_config = ConfigDict(extra="allow") + + +class ListPromptsResult(Result): + """The server's response to a prompts/list request from the client.""" + + prompts: list[Prompt] + + +class GetPromptRequestParams(RequestParams): + """Parameters for getting a prompt.""" + + name: str + """The name of the prompt or prompt template.""" + arguments: dict[str, str] | None = None + """Arguments to use for templating the prompt.""" + model_config = ConfigDict(extra="allow") + + +class GetPromptRequest(Request): + """Used by the client to get a prompt provided by the server.""" + + method: Literal["prompts/get"] + params: GetPromptRequestParams + + +class TextContent(BaseModel): + """Text content for a message.""" + + type: Literal["text"] + text: str + """The text content of the message.""" + model_config = ConfigDict(extra="allow") + + +class ImageContent(BaseModel): + """Image content for a message.""" + + type: Literal["image"] + data: str + """The base64-encoded image data.""" + mimeType: str + """The MIME type of the image. Different providers may support different image types.""" + model_config = ConfigDict(extra="allow") + + +Role = Literal["user", "assistant"] + + +class SamplingMessage(BaseModel): + """Describes a message issued to or received from an LLM API.""" + + role: Role + content: TextContent | ImageContent + model_config = ConfigDict(extra="allow") + + +class GetPromptResult(Result): + """The server's response to a prompts/get request from the client.""" + + description: str | None = None + """An optional description for the prompt.""" + messages: list[SamplingMessage] + + +class ListToolsRequest(Request): + """Sent from the client to request a list of tools the server has.""" + + method: Literal["tools/list"] + params: RequestParams | None = None + + +class Tool(BaseModel): + """Definition for a tool the client can call.""" + + name: str + """The name of the tool.""" + description: str | None = None + """A human-readable description of the tool.""" + inputSchema: dict[str, Any] + """A JSON Schema object defining the expected parameters for the tool.""" + model_config = ConfigDict(extra="allow") + + +class ListToolsResult(Result): + """The server's response to a tools/list request from the client.""" + + tools: list[Tool] + + +class CallToolRequestParams(RequestParams): + """Parameters for calling a tool.""" + + name: str + arguments: dict[str, Any] | None = None + model_config = ConfigDict(extra="allow") + + +class CallToolRequest(Request): + """Used by the client to invoke a tool provided by the server.""" + + method: Literal["tools/call"] + params: CallToolRequestParams + + +class CallToolResult(Result): + """The server's response to a tool call.""" + + toolResult: Any + + +class ToolListChangedNotification(Notification): + """An optional notification from the server to the client, informing it that the list of tools it offers has changed.""" + + method: Literal["notifications/tools/list_changed"] + params: NotificationParams | None = None + + +LoggingLevel = Literal["debug", "info", "warning", "error"] + + +class SetLevelRequestParams(RequestParams): + """Parameters for setting the logging level.""" + + level: LoggingLevel + """The level of logging that the client wants to receive from the server.""" + model_config = ConfigDict(extra="allow") + + +class SetLevelRequest(Request): + """A request from the client to the server, to enable or adjust logging.""" + + method: Literal["logging/setLevel"] + params: SetLevelRequestParams + + +class LoggingMessageNotificationParams(NotificationParams): + """Parameters for logging message notifications.""" + + level: LoggingLevel + """The severity of this log message.""" + logger: str | None = None + """An optional name of the logger issuing this message.""" + data: Any + """The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here.""" + model_config = ConfigDict(extra="allow") + + +class LoggingMessageNotification(Notification): + """Notification of a log message passed from server to client.""" + + method: Literal["notifications/message"] + params: LoggingMessageNotificationParams + + +IncludeContext = Literal["none", "thisServer", "allServers"] + + +class CreateMessageRequestParams(RequestParams): + """Parameters for creating a message.""" + + messages: list[SamplingMessage] + systemPrompt: str | None = None + """An optional system prompt the server wants to use for sampling.""" + includeContext: IncludeContext | None = None + """A request to include context from one or more MCP servers (including the caller), to be attached to the prompt.""" + temperature: float | None = None + maxTokens: int + """The maximum number of tokens to sample, as requested by the server.""" + stopSequences: list[str] | None = None + metadata: dict[str, Any] | None = None + """Optional metadata to pass through to the LLM provider.""" + model_config = ConfigDict(extra="allow") + + +class CreateMessageRequest(Request): + """A request from the server to sample an LLM via the client.""" + + method: Literal["sampling/createMessage"] + params: CreateMessageRequestParams + + +StopReason = Literal["endTurn", "stopSequence", "maxTokens"] + + +class CreateMessageResult(Result): + """The client's response to a sampling/create_message request from the server.""" + + role: Role + content: TextContent | ImageContent + model: str + """The name of the model that generated the message.""" + stopReason: StopReason + """The reason why sampling stopped.""" + + +class ResourceReference(BaseModel): + """A reference to a resource or resource template definition.""" + + type: Literal["ref/resource"] + uri: str + """The URI or URI template of the resource.""" + model_config = ConfigDict(extra="allow") + + +class PromptReference(BaseModel): + """Identifies a prompt.""" + + type: Literal["ref/prompt"] + name: str + """The name of the prompt or prompt template""" + model_config = ConfigDict(extra="allow") + + +class CompletionArgument(BaseModel): + """The argument's information for completion requests.""" + + name: str + """The name of the argument""" + value: str + """The value of the argument to use for completion matching.""" + model_config = ConfigDict(extra="allow") + + +class CompleteRequestParams(RequestParams): + """Parameters for completion requests.""" + + ref: ResourceReference | PromptReference + argument: CompletionArgument + model_config = ConfigDict(extra="allow") + + +class CompleteRequest(Request): + """A request from the client to the server, to ask for completion options.""" + + method: Literal["completion/complete"] + params: CompleteRequestParams + + +class Completion(BaseModel): + """Completion information.""" + + values: list[str] + """An array of completion values. Must not exceed 100 items.""" + total: int | None = None + """The total number of completion options available. This can exceed the number of values actually sent in the response.""" + hasMore: bool | None = None + """Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown.""" + model_config = ConfigDict(extra="allow") + + +class CompleteResult(Result): + """The server's response to a completion/complete request""" + + completion: Completion + + +class ClientRequest( + RootModel[ + PingRequest + | InitializeRequest + | CompleteRequest + | SetLevelRequest + | GetPromptRequest + | ListPromptsRequest + | ListResourcesRequest + | ReadResourceRequest + | SubscribeRequest + | UnsubscribeRequest + | CallToolRequest + | ListToolsRequest + ] +): + pass + + +class ClientNotification(RootModel[ProgressNotification | InitializedNotification]): + pass + + +class ClientResult(RootModel[EmptyResult | CreateMessageResult]): + pass + + +class ServerRequest(RootModel[PingRequest | CreateMessageRequest]): + pass + + +class ServerNotification( + RootModel[ + ProgressNotification + | LoggingMessageNotification + | ResourceUpdatedNotification + | ResourceListChangedNotification + | ToolListChangedNotification + ] +): + pass + + +class ServerResult( + RootModel[ + EmptyResult + | InitializeResult + | CompleteResult + | GetPromptResult + | ListPromptsResult + | ListResourcesResult + | ReadResourceResult + | CallToolResult + | ListToolsResult + ] +): + pass diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9710bec --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,34 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "mcp-python" +version = "0.1.2" +description = "Model Context Protocol implementation for Python" +readme = "README.md" +requires-python = ">=3.8" +dependencies = [ + "anyio", + "httpx", + "httpx-sse", + "pydantic>=2.0.0", + "starlette", + "sse-starlette", +] + +[tool.hatch.build.targets.wheel] +packages = ["mcp_python"] + +[tool.pyright] +include = ["mcp_python", "tests"] +typeCheckingMode = "strict" + +[tool.ruff] +select = ["E", "F", "I"] +ignore = [] +line-length = 88 +target-version = "py38" + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/client/__init__.py b/tests/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/client/test_session.py b/tests/client/test_session.py new file mode 100644 index 0000000..adfc6df --- /dev/null +++ b/tests/client/test_session.py @@ -0,0 +1,93 @@ +import anyio +import pytest + +from mcp_python.client.session import ClientSession +from mcp_python.types import ( + ClientNotification, + ClientRequest, + Implementation, + InitializedNotification, + InitializeRequest, + InitializeResult, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + ServerResult, +) + + +@pytest.mark.anyio +async def test_client_session_initialize(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + JSONRPCMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + JSONRPCMessage + ](1) + + initialized_notification = None + + async def mock_server(): + nonlocal initialized_notification + + jsonrpc_request = await client_to_server_receive.receive() + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json") + ) + assert isinstance(request.root, InitializeRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=1, + capabilities=ServerCapabilities( + logging=None, + resources=None, + tools=None, + experimental=None, + prompts=None, + ), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json"), + ) + ) + ) + jsonrpc_notification = await client_to_server_receive.receive() + assert isinstance(jsonrpc_notification.root, JSONRPCNotification) + initialized_notification = ClientNotification.model_validate( + jsonrpc_notification.model_dump(by_alias=True, mode="json") + ) + + async def listen_session(): + async for message in session.incoming_messages: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession(server_to_client_receive, client_to_server_send) as session, + anyio.create_task_group() as tg, + ): + tg.start_soon(mock_server) + tg.start_soon(listen_session) + result = await session.initialize() + + # Assert the result + assert isinstance(result, InitializeResult) + assert result.protocolVersion == 1 + assert isinstance(result.capabilities, ServerCapabilities) + assert result.serverInfo == Implementation(name="mock-server", version="0.1.0") + + # Check that the client sent the initialized notification + assert initialized_notification + assert isinstance(initialized_notification.root, InitializedNotification) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py new file mode 100644 index 0000000..b5e168b --- /dev/null +++ b/tests/client/test_stdio.py @@ -0,0 +1,38 @@ +import pytest + +from mcp_python.client.stdio import StdioServerParameters, stdio_client +from mcp_python.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse + + +@pytest.mark.anyio +async def test_stdio_client(): + server_parameters = StdioServerParameters(command="/usr/bin/tee") + + async with stdio_client(server_parameters) as (read_stream, write_stream): + # Test sending and receiving messages + messages = [ + JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), + JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + ] + + async with write_stream: + for message in messages: + await write_stream.send(message) + + read_messages = [] + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + + read_messages.append(message) + if len(read_messages) == 2: + break + + assert len(read_messages) == 2 + assert read_messages[0] == JSONRPCMessage( + root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + ) + assert read_messages[1] == JSONRPCMessage( + root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) + ) diff --git a/tests/server/__init__.py b/tests/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/server/test_session.py b/tests/server/test_session.py new file mode 100644 index 0000000..50994cc --- /dev/null +++ b/tests/server/test_session.py @@ -0,0 +1,59 @@ +import anyio +import pytest + +from mcp_python.client.session import ClientSession +from mcp_python.server.session import ServerSession +from mcp_python.types import ( + ClientNotification, + InitializedNotification, + JSONRPCMessage, +) + + +@pytest.mark.anyio +async def test_server_session_initialize(): + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream( + 1, JSONRPCMessage + ) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream( + 1, JSONRPCMessage + ) + + async def run_client(client: ClientSession): + async for message in client_session.incoming_messages: + if isinstance(message, Exception): + raise message + + received_initialized = False + + async def run_server(): + nonlocal received_initialized + + async with ServerSession( + client_to_server_receive, server_to_client_send + ) as server_session: + async for message in server_session.incoming_messages: + if isinstance(message, Exception): + raise message + + if isinstance(message, ClientNotification) and isinstance( + message.root, InitializedNotification + ): + received_initialized = True + return + + try: + async with ( + ClientSession( + server_to_client_receive, client_to_server_send + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_client, client_session) + tg.start_soon(run_server) + + await client_session.initialize() + except* anyio.ClosedResourceError: + pass + + assert received_initialized diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py new file mode 100644 index 0000000..782b750 --- /dev/null +++ b/tests/server/test_stdio.py @@ -0,0 +1,68 @@ +import io + +import anyio +import pytest + +from mcp_python.server.stdio import stdio_server +from mcp_python.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse + + +@pytest.mark.anyio +async def test_stdio_server(): + stdin = io.StringIO() + stdout = io.StringIO() + + messages = [ + JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), + JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + ] + + for message in messages: + stdin.write(message.model_dump_json() + "\n") + stdin.seek(0) + + async with stdio_server( + stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout) + ) as (read_stream, write_stream): + received_messages = [] + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + received_messages.append(message) + if len(received_messages) == 2: + break + + # Verify received messages + assert len(received_messages) == 2 + assert received_messages[0] == JSONRPCMessage( + root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + ) + assert received_messages[1] == JSONRPCMessage( + root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) + ) + + # Test sending responses from the server + responses = [ + JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")), + JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})), + ] + + async with write_stream: + for response in responses: + await write_stream.send(response) + + stdout.seek(0) + output_lines = stdout.readlines() + assert len(output_lines) == 2 + + received_responses = [ + JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines + ] + assert len(received_responses) == 2 + assert received_responses[0] == JSONRPCMessage( + root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") + ) + assert received_responses[1] == JSONRPCMessage( + root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) + ) diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..75cac11 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,24 @@ +from mcp_python.types import ClientRequest, JSONRPCMessage, JSONRPCRequest + + +def test_jsonrpc_request(): + json_data = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": 1, + "capabilities": {"batch": None, "sampling": None}, + "clientInfo": {"name": "mcp_python", "version": "0.1.0"}, + }, + } + + request = JSONRPCMessage.model_validate(json_data) + assert isinstance(request.root, JSONRPCRequest) + ClientRequest.model_validate(request.model_dump(by_alias=True)) + + assert request.root.jsonrpc == "2.0" + assert request.root.id == 1 + assert request.root.method == "initialize" + assert request.root.params is not None + assert request.root.params["protocolVersion"] == 1