mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 14:34:27 +01:00
Initial import
This commit is contained in:
22
.devcontainer/devcontainer.json
Normal file
22
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,22 @@
|
||||
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||
// README at: https://github.com/devcontainers/templates/tree/main/src/python
|
||||
{
|
||||
"name": "Python 3",
|
||||
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
||||
"image": "mcr.microsoft.com/devcontainers/python:1-3.12-bookworm"
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
// Use 'postCreateCommand' to run commands after the container is created.
|
||||
// "postCreateCommand": "pip3 install --user -r requirements.txt",
|
||||
|
||||
// Configure tool-specific properties.
|
||||
// "customizations": {},
|
||||
|
||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||
// "remoteUser": "root"
|
||||
}
|
||||
162
.gitignore
vendored
Normal file
162
.gitignore
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
2
README.md
Normal file
2
README.md
Normal file
@@ -0,0 +1,2 @@
|
||||
# mcp-python
|
||||
Model Context Protocol implementation for Python
|
||||
104
mcp_python/__init__.py
Normal file
104
mcp_python/__init__.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from .client.session import ClientSession
|
||||
from .client.stdio import StdioServerParameters, stdio_client
|
||||
from .server.session import ServerSession
|
||||
from .server.stdio import stdio_server
|
||||
from .shared.exceptions import McpError
|
||||
from .types import (
|
||||
CallToolRequest,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
CompleteRequest,
|
||||
CreateMessageRequest,
|
||||
CreateMessageResult,
|
||||
ErrorData,
|
||||
GetPromptRequest,
|
||||
GetPromptResult,
|
||||
Implementation,
|
||||
IncludeContext,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ListPromptsRequest,
|
||||
ListPromptsResult,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResult,
|
||||
ListToolsResult,
|
||||
LoggingLevel,
|
||||
LoggingMessageNotification,
|
||||
Notification,
|
||||
PingRequest,
|
||||
ProgressNotification,
|
||||
ReadResourceRequest,
|
||||
ReadResourceResult,
|
||||
Resource,
|
||||
ResourceUpdatedNotification,
|
||||
Role as SamplingRole,
|
||||
SamplingMessage,
|
||||
ServerCapabilities,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
SetLevelRequest,
|
||||
StopReason,
|
||||
SubscribeRequest,
|
||||
Tool,
|
||||
UnsubscribeRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CallToolRequest",
|
||||
"ClientCapabilities",
|
||||
"ClientNotification",
|
||||
"ClientRequest",
|
||||
"ClientResult",
|
||||
"ClientSession",
|
||||
"CreateMessageRequest",
|
||||
"CreateMessageResult",
|
||||
"ErrorData",
|
||||
"GetPromptRequest",
|
||||
"GetPromptResult",
|
||||
"Implementation",
|
||||
"IncludeContext",
|
||||
"InitializeRequest",
|
||||
"InitializeResult",
|
||||
"InitializedNotification",
|
||||
"JSONRPCError",
|
||||
"JSONRPCRequest",
|
||||
"ListPromptsRequest",
|
||||
"ListPromptsResult",
|
||||
"ListResourcesRequest",
|
||||
"ListResourcesResult",
|
||||
"ListToolsResult",
|
||||
"LoggingLevel",
|
||||
"LoggingMessageNotification",
|
||||
"McpError",
|
||||
"Notification",
|
||||
"PingRequest",
|
||||
"ProgressNotification",
|
||||
"ReadResourceRequest",
|
||||
"ReadResourceResult",
|
||||
"ResourceUpdatedNotification",
|
||||
"Resource",
|
||||
"SamplingMessage",
|
||||
"SamplingRole",
|
||||
"ServerCapabilities",
|
||||
"ServerNotification",
|
||||
"ServerRequest",
|
||||
"ServerResult",
|
||||
"ServerSession",
|
||||
"SetLevelRequest",
|
||||
"StdioServerParameters",
|
||||
"StopReason",
|
||||
"SubscribeRequest",
|
||||
"Tool",
|
||||
"UnsubscribeRequest",
|
||||
"stdio_client",
|
||||
"stdio_server",
|
||||
"CompleteRequest",
|
||||
"JSONRPCResponse",
|
||||
]
|
||||
0
mcp_python/client/__init__.py
Normal file
0
mcp_python/client/__init__.py
Normal file
76
mcp_python/client/__main__.py
Normal file
76
mcp_python/client/__main__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
import click
|
||||
|
||||
from mcp_python.client.session import ClientSession
|
||||
from mcp_python.client.sse import sse_client
|
||||
from mcp_python.client.stdio import StdioServerParameters, stdio_client
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def receive_loop(session: ClientSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
|
||||
logger.info("Received message from server: %s", message)
|
||||
|
||||
|
||||
async def run_session(read_stream, write_stream):
|
||||
async with (
|
||||
ClientSession(read_stream, write_stream) as session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
tg.start_soon(receive_loop, session)
|
||||
|
||||
logger.info("Initializing session")
|
||||
await session.initialize()
|
||||
logger.info("Initialized")
|
||||
|
||||
|
||||
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
|
||||
env_dict = dict(env)
|
||||
|
||||
if urlparse(command_or_url).scheme in ("http", "https"):
|
||||
# Use SSE client for HTTP(S) URLs
|
||||
async with sse_client(command_or_url) as streams:
|
||||
await run_session(*streams)
|
||||
else:
|
||||
# Use stdio client for commands
|
||||
server_parameters = StdioServerParameters(
|
||||
command=command_or_url, args=args, env=env_dict
|
||||
)
|
||||
async with stdio_client(server_parameters) as streams:
|
||||
await run_session(*streams)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("command_or_url")
|
||||
@click.argument("args", nargs=-1)
|
||||
@click.option(
|
||||
"--env",
|
||||
"-e",
|
||||
multiple=True,
|
||||
nargs=2,
|
||||
metavar="KEY VALUE",
|
||||
help="Environment variables to set. Can be used multiple times.",
|
||||
)
|
||||
def cli(*args, **kwargs):
|
||||
anyio.run(partial(main, *args, **kwargs), backend="trio")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
211
mcp_python/client/session.py
Normal file
211
mcp_python/client/session.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.shared.session import BaseSession
|
||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
||||
from mcp_python.types import (
|
||||
CallToolResult,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
EmptyResult,
|
||||
Implementation,
|
||||
InitializedNotification,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
ListResourcesResult,
|
||||
LoggingLevel,
|
||||
ReadResourceResult,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
ClientRequest,
|
||||
ClientNotification,
|
||||
ClientResult,
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, ServerRequest, ServerNotification)
|
||||
|
||||
async def initialize(self) -> InitializeResult:
|
||||
from mcp_python.types import (
|
||||
InitializeRequest,
|
||||
InitializeRequestParams,
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
ClientRequest(
|
||||
InitializeRequest(
|
||||
method="initialize",
|
||||
params=InitializeRequestParams(
|
||||
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
|
||||
capabilities=ClientCapabilities(
|
||||
sampling=None, experimental=None
|
||||
),
|
||||
clientInfo=Implementation(name="mcp_python", version="0.1.0"),
|
||||
),
|
||||
)
|
||||
),
|
||||
InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION:
|
||||
raise RuntimeError(
|
||||
f"Unsupported protocol version from the server: {result.protocolVersion}"
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
InitializedNotification(method="notifications/initialized")
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def send_ping(self) -> EmptyResult:
|
||||
"""Send a ping request."""
|
||||
from mcp_python.types import PingRequest
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
from mcp_python.types import (
|
||||
ProgressNotification,
|
||||
ProgressNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
from mcp_python.types import (
|
||||
SetLevelRequest,
|
||||
SetLevelRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def list_resources(self) -> ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
from mcp_python.types import (
|
||||
ListResourcesRequest,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
ListResourcesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
from mcp_python.types import (
|
||||
ReadResourceRequest,
|
||||
ReadResourceRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
ReadResourceResult,
|
||||
)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
from mcp_python.types import (
|
||||
SubscribeRequest,
|
||||
SubscribeRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
from mcp_python.types import (
|
||||
UnsubscribeRequest,
|
||||
UnsubscribeRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict | None = None
|
||||
) -> CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
from mcp_python.types import (
|
||||
CallToolRequest,
|
||||
CallToolRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
CallToolRequest(
|
||||
method="tools/call",
|
||||
params=CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
CallToolResult,
|
||||
)
|
||||
129
mcp_python/client/sse.py
Normal file
129
mcp_python/client/sse.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client(url: str, timeout: float = 5, sse_read_timeout: float = 60 * 5):
|
||||
"""
|
||||
Client transport for SSE.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
url,
|
||||
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
|
||||
async def sse_reader(
|
||||
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||
):
|
||||
try:
|
||||
async for sse in event_source.aiter_sse():
|
||||
logger.debug(f"Received SSE event: {sse.event}")
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
logger.info(
|
||||
f"Received endpoint URL: {endpoint_url}"
|
||||
)
|
||||
|
||||
url_parsed = urlparse(url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
if (
|
||||
url_parsed.netloc != endpoint_parsed.netloc
|
||||
or url_parsed.scheme
|
||||
!= endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
task_status.started(endpoint_url)
|
||||
|
||||
case "message":
|
||||
try:
|
||||
message = (
|
||||
JSONRPCMessage.model_validate_json(
|
||||
sse.data
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"Received server message: {message}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error parsing server message: {exc}"
|
||||
)
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in sse_reader: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
|
||||
async def post_writer(endpoint_url: str):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=message.model_dump(by_alias=True, mode="json"),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(
|
||||
f"Client message sent successfully: {response.status_code}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
finally:
|
||||
await write_stream.aclose()
|
||||
|
||||
endpoint_url = await tg.start(sse_reader)
|
||||
logger.info(
|
||||
f"Starting post writer with endpoint URL: {endpoint_url}"
|
||||
)
|
||||
tg.start_soon(post_writer, endpoint_url)
|
||||
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
84
mcp_python/client/stdio.py
Normal file
84
mcp_python/client/stdio.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
|
||||
class StdioServerParameters(BaseModel):
|
||||
command: str
|
||||
"""The executable to run to start the server."""
|
||||
|
||||
args: list[str] = Field(default_factory=list)
|
||||
"""Command line arguments to pass to the executable."""
|
||||
|
||||
env: dict[str, str] = Field(default_factory=dict)
|
||||
"""
|
||||
The environment to use when spawning the process.
|
||||
|
||||
The environment is NOT inherited from the parent process by default.
|
||||
"""
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_client(server: StdioServerParameters):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
process = await anyio.open_process(
|
||||
[server.command, *server.args], env=server.env, stderr=sys.stderr
|
||||
)
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
buffer = ""
|
||||
async for chunk in TextReceiveStream(process.stdout):
|
||||
lines = (buffer + chunk).split("\n")
|
||||
buffer = lines.pop()
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdin_writer():
|
||||
assert process.stdin, "Opened process is missing stdin"
|
||||
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True)
|
||||
await process.stdin.send((json + "\n").encode())
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with (
|
||||
anyio.create_task_group() as tg,
|
||||
process,
|
||||
):
|
||||
tg.start_soon(stdout_reader)
|
||||
tg.start_soon(stdin_writer)
|
||||
yield read_stream, write_stream
|
||||
0
mcp_python/py.typed
Normal file
0
mcp_python/py.typed
Normal file
347
mcp_python/server/__init__.py
Normal file
347
mcp_python/server/__init__.py
Normal file
@@ -0,0 +1,347 @@
|
||||
import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.server import types
|
||||
from mcp_python.server.session import ServerSession
|
||||
from mcp_python.server.stdio import stdio_server as stdio_server
|
||||
from mcp_python.shared.context import RequestContext
|
||||
from mcp_python.shared.session import RequestResponder
|
||||
from mcp_python.types import (
|
||||
METHOD_NOT_FOUND,
|
||||
CallToolRequest,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CompleteRequest,
|
||||
ErrorData,
|
||||
JSONRPCMessage,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResult,
|
||||
LoggingLevel,
|
||||
ProgressNotification,
|
||||
Prompt,
|
||||
PromptReference,
|
||||
ReadResourceRequest,
|
||||
ReadResourceResult,
|
||||
Resource,
|
||||
ResourceReference,
|
||||
ServerResult,
|
||||
SetLevelRequest,
|
||||
SubscribeRequest,
|
||||
UnsubscribeRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
|
||||
"request_ctx"
|
||||
)
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
logger.info(f"Initializing server '{name}'")
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
def list_prompts(self):
|
||||
from mcp_python.types import ListPromptsRequest, ListPromptsResult
|
||||
|
||||
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
|
||||
logger.debug(f"Registering handler for PromptListRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
prompts = await func()
|
||||
return ServerResult(ListPromptsResult(prompts=prompts))
|
||||
|
||||
self.request_handlers[ListPromptsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_prompt(self):
|
||||
from mcp_python.types import (
|
||||
GetPromptRequest,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
Role as Role,
|
||||
SamplingMessage,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str, dict[str, str] | None], Awaitable[types.PromptResponse]
|
||||
],
|
||||
):
|
||||
logger.debug(f"Registering handler for GetPromptRequest")
|
||||
|
||||
async def handler(req: GetPromptRequest):
|
||||
prompt_get = await func(req.params.name, req.params.arguments)
|
||||
messages = []
|
||||
for message in prompt_get.messages:
|
||||
match message.content:
|
||||
case str() as text_content:
|
||||
content = TextContent(type="text", text=text_content)
|
||||
case types.ImageContent() as img_content:
|
||||
content = ImageContent(
|
||||
type="image",
|
||||
data=img_content.data,
|
||||
mimeType=img_content.mime_type,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {type(message.content)}"
|
||||
)
|
||||
|
||||
sampling_message = SamplingMessage(
|
||||
role=message.role, content=content
|
||||
)
|
||||
messages.append(sampling_message)
|
||||
|
||||
return ServerResult(
|
||||
GetPromptResult(description=prompt_get.desc, messages=messages)
|
||||
)
|
||||
|
||||
self.request_handlers[GetPromptRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resources(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
|
||||
logger.debug(f"Registering handler for ListResourcesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
return ServerResult(
|
||||
ListResourcesResult(resources=resources, resourceTemplates=None)
|
||||
)
|
||||
|
||||
self.request_handlers[ListResourcesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def read_resource(self):
|
||||
from mcp_python.types import (
|
||||
BlobResourceContents,
|
||||
TextResourceContents,
|
||||
)
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
|
||||
logger.debug(f"Registering handler for ReadResourceRequest")
|
||||
|
||||
async def handler(req: ReadResourceRequest):
|
||||
result = await func(req.params.uri)
|
||||
match result:
|
||||
case str(s):
|
||||
content = TextResourceContents(
|
||||
uri=req.params.uri,
|
||||
text=s,
|
||||
mimeType="text/plain",
|
||||
)
|
||||
case bytes(b):
|
||||
import base64
|
||||
|
||||
content = BlobResourceContents(
|
||||
uri=req.params.uri,
|
||||
blob=base64.urlsafe_b64encode(b).decode(),
|
||||
mimeType="application/octet-stream",
|
||||
)
|
||||
|
||||
return ServerResult(
|
||||
ReadResourceResult(
|
||||
contents=[content],
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[ReadResourceRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_logging_level(self):
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for SetLevelRequest")
|
||||
|
||||
async def handler(req: SetLevelRequest):
|
||||
await func(req.params.level)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[SetLevelRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def subscribe_resource(self):
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for SubscribeRequest")
|
||||
|
||||
async def handler(req: SubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[SubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def unsubscribe_resource(self):
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for UnsubscribeRequest")
|
||||
|
||||
async def handler(req: UnsubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[UnsubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def call_tool(self):
|
||||
from mcp_python.types import CallToolResult
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
logger.debug(f"Registering handler for CallToolRequest")
|
||||
|
||||
async def handler(req: CallToolRequest):
|
||||
result = await func(req.params.name, **(req.params.arguments or {}))
|
||||
return ServerResult(CallToolResult(toolResult=result))
|
||||
|
||||
self.request_handlers[CallToolRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug(f"Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken, req.params.progress, req.params.total
|
||||
)
|
||||
|
||||
self.notification_handlers[ProgressNotification] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def completion(self):
|
||||
"""Provides completions for prompts and resource templates"""
|
||||
from mcp_python.types import CompleteResult, Completion, CompletionArgument
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[PromptReference | ResourceReference, CompletionArgument],
|
||||
Awaitable[Completion | None],
|
||||
],
|
||||
):
|
||||
logger.debug(f"Registering handler for CompleteRequest")
|
||||
|
||||
async def handler(req: CompleteRequest):
|
||||
completion = await func(req.params.ref, req.params.argument)
|
||||
return ServerResult(
|
||||
CompleteResult(
|
||||
completion=completion
|
||||
if completion is not None
|
||||
else Completion(values=[], total=None, hasMore=None),
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[CompleteRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(read_stream, write_stream) as session:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
match message:
|
||||
case RequestResponder(request=ClientRequest(root=req)):
|
||||
logger.info(
|
||||
f"Processing request of type {type(req).__name__}"
|
||||
)
|
||||
if type(req) in self.request_handlers:
|
||||
handler = self.request_handlers[type(req)]
|
||||
logger.debug(
|
||||
f"Dispatching request of type {type(req).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
token = request_ctx.set(
|
||||
RequestContext(
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
# Reset the global state after we are done
|
||||
request_ctx.reset(token)
|
||||
except Exception as err:
|
||||
response = ErrorData(
|
||||
code=0, message=str(err), data=None
|
||||
)
|
||||
|
||||
await message.respond(response)
|
||||
else:
|
||||
await message.respond(
|
||||
ErrorData(
|
||||
code=METHOD_NOT_FOUND,
|
||||
message="Method not found",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Response sent")
|
||||
case ClientNotification(root=notify):
|
||||
if type(notify) in self.notification_handlers:
|
||||
assert type(notify) in self.notification_handlers
|
||||
|
||||
handler = self.notification_handlers[type(notify)]
|
||||
logger.debug(
|
||||
f"Dispatching notification of type {type(notify).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(notify)
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"Uncaught exception in notification handler: {err}"
|
||||
)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
)
|
||||
35
mcp_python/server/__main__.py
Normal file
35
mcp_python/server/__main__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp_python.server.session import ServerSession
|
||||
from mcp_python.server.stdio import stdio_server
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("server")
|
||||
|
||||
|
||||
async def receive_loop(session: ServerSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
|
||||
logger.info("Received message from client: %s", message)
|
||||
|
||||
|
||||
async def main():
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
async with ServerSession(read_stream, write_stream) as session, write_stream:
|
||||
await receive_loop(session)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
anyio.run(main, backend="trio")
|
||||
203
mcp_python/server/session.py
Normal file
203
mcp_python/server/session.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
)
|
||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
||||
from mcp_python.types import (
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CreateMessageResult,
|
||||
EmptyResult,
|
||||
Implementation,
|
||||
IncludeContext,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
LoggingLevel,
|
||||
SamplingMessage,
|
||||
ServerCapabilities,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
)
|
||||
|
||||
|
||||
class InitializationState(Enum):
|
||||
NotInitialized = 1
|
||||
Initializing = 2
|
||||
Initialized = 3
|
||||
|
||||
|
||||
class ServerSession(
|
||||
BaseSession[
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
ServerResult,
|
||||
ClientRequest,
|
||||
ClientNotification,
|
||||
]
|
||||
):
|
||||
_initialized: InitializationState = InitializationState.NotInitialized
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
||||
self._initialization_state = InitializationState.NotInitialized
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[ClientRequest, ServerResult]
|
||||
):
|
||||
match responder.request.root:
|
||||
case InitializeRequest():
|
||||
self._initialization_state = InitializationState.Initializing
|
||||
await responder.respond(
|
||||
ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(
|
||||
logging=None,
|
||||
resources=None,
|
||||
tools=None,
|
||||
experimental=None,
|
||||
prompts={},
|
||||
),
|
||||
serverInfo=Implementation(
|
||||
name="mcp_python", version="0.1.0"
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received request before initialization was complete"
|
||||
)
|
||||
|
||||
async def _received_notification(self, notification: ClientNotification) -> None:
|
||||
# Need this to avoid ASYNC910
|
||||
await anyio.lowlevel.checkpoint()
|
||||
match notification.root:
|
||||
case InitializedNotification():
|
||||
self._initialization_state = InitializationState.Initialized
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received notification before initialization was complete"
|
||||
)
|
||||
|
||||
async def send_log_message(
|
||||
self, level: LoggingLevel, data: Any, logger: str | None = None
|
||||
) -> None:
|
||||
"""Send a log message notification."""
|
||||
from mcp_python.types import (
|
||||
LoggingMessageNotification,
|
||||
LoggingMessageNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
LoggingMessageNotification(
|
||||
method="notifications/message",
|
||||
params=LoggingMessageNotificationParams(
|
||||
level=level,
|
||||
data=data,
|
||||
logger=logger,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
||||
"""Send a resource updated notification."""
|
||||
from mcp_python.types import (
|
||||
ResourceUpdatedNotification,
|
||||
ResourceUpdatedNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ResourceUpdatedNotification(
|
||||
method="notifications/resources/updated",
|
||||
params=ResourceUpdatedNotificationParams(uri=uri),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def request_create_message(
|
||||
self,
|
||||
messages: list[SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
system_prompt: str | None = None,
|
||||
include_context: IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreateMessageResult:
|
||||
"""Send a sampling/create_message request."""
|
||||
from mcp_python.types import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
CreateMessageRequest(
|
||||
method="sampling/createMessage",
|
||||
params=CreateMessageRequestParams(
|
||||
messages=messages,
|
||||
systemPrompt=system_prompt,
|
||||
includeContext=include_context,
|
||||
temperature=temperature,
|
||||
maxTokens=max_tokens,
|
||||
stopSequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
),
|
||||
)
|
||||
),
|
||||
CreateMessageResult,
|
||||
)
|
||||
|
||||
async def send_ping(self) -> EmptyResult:
|
||||
"""Send a ping request."""
|
||||
from mcp_python.types import PingRequest
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
from mcp_python.types import ProgressNotification, ProgressNotificationParams
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
133
mcp_python/server/sse.py
Normal file
133
mcp_python/server/sse.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SseServerTransport:
|
||||
"""
|
||||
SSE server transport for MCP. This class provides _two_ ASGI applications, suitable to be used with a framework like Starlette and a server like Hypercorn:
|
||||
|
||||
1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client.
|
||||
2. handle_post_message() is an ASGI application which receives incoming POST requests, which should contain client messages that link to a previously-established SSE session.
|
||||
"""
|
||||
|
||||
_endpoint: str
|
||||
_read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]]
|
||||
|
||||
def __init__(self, endpoint: str) -> None:
|
||||
"""
|
||||
Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._endpoint = endpoint
|
||||
self._read_stream_writers = {}
|
||||
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
logger.error("connect_sse received non-HTTP request")
|
||||
raise ValueError("connect_sse can only handle HTTP requests")
|
||||
|
||||
logger.debug("Setting up SSE connection")
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
session_id = uuid4()
|
||||
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
|
||||
self._read_stream_writers[session_id] = read_stream_writer
|
||||
logger.debug(f"Created new session with ID: {session_id}")
|
||||
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(
|
||||
0, dict[str, Any]
|
||||
)
|
||||
|
||||
async def sse_writer():
|
||||
logger.debug("Starting SSE writer")
|
||||
async with sse_stream_writer, write_stream_reader:
|
||||
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
|
||||
logger.debug(f"Sent endpoint event: {session_uri}")
|
||||
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending message via SSE: {message}")
|
||||
await sse_stream_writer.send(
|
||||
{
|
||||
"event": "message",
|
||||
"data": message.model_dump_json(by_alias=True),
|
||||
}
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
response = EventSourceResponse(
|
||||
content=sse_stream_reader, data_sender_callable=sse_writer
|
||||
)
|
||||
logger.debug("Starting SSE response task")
|
||||
tg.start_soon(response, scope, receive, send)
|
||||
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
async def handle_post_message(
|
||||
self, scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
logger.debug("Handling POST message")
|
||||
request = Request(scope, receive)
|
||||
|
||||
session_id_param = request.query_params.get("session_id")
|
||||
if session_id_param is None:
|
||||
logger.warning("Received request without session_id")
|
||||
response = Response("session_id is required", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
try:
|
||||
session_id = UUID(hex=session_id_param)
|
||||
logger.debug(f"Parsed session ID: {session_id}")
|
||||
except ValueError:
|
||||
logger.warning(f"Received invalid session ID: {session_id_param}")
|
||||
response = Response("Invalid session ID", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
writer = self._read_stream_writers.get(session_id)
|
||||
if not writer:
|
||||
logger.warning(f"Could not find session for ID: {session_id}")
|
||||
response = Response("Could not find session", status_code=404)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
json = await request.json()
|
||||
logger.debug(f"Received JSON: {json}")
|
||||
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate(json)
|
||||
logger.debug(f"Validated client message: {message}")
|
||||
except ValidationError as err:
|
||||
logger.error(f"Failed to parse message: {err}")
|
||||
response = Response("Could not parse message", status_code=400)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(err)
|
||||
return
|
||||
|
||||
logger.debug(f"Sending message to writer: {message}")
|
||||
response = Response("Accepted", status_code=202)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(message)
|
||||
60
mcp_python/server/stdio.py
Normal file
60
mcp_python/server/stdio.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_server(
|
||||
stdin: anyio.AsyncFile | None = None, stdout: anyio.AsyncFile | None = None
|
||||
):
|
||||
"""
|
||||
Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout.
|
||||
"""
|
||||
# Purposely not using context managers for these, as we don't want to close standard process handles.
|
||||
if not stdin:
|
||||
stdin = anyio.wrap_file(sys.stdin)
|
||||
if not stdout:
|
||||
stdout = anyio.wrap_file(sys.stdout)
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async def stdin_reader():
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for line in stdin:
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdout_writer():
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True)
|
||||
await stdout.write(json + "\n")
|
||||
await stdout.flush()
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(stdin_reader)
|
||||
tg.start_soon(stdout_writer)
|
||||
yield read_stream, write_stream
|
||||
27
mcp_python/server/types.py
Normal file
27
mcp_python/server/types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
This module provides simpler types to use with the server for managing prompts.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from mcp_python.types import Role
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageContent:
|
||||
type: Literal["image"]
|
||||
data: str
|
||||
mime_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
role: Role
|
||||
content: str | ImageContent
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptResponse:
|
||||
messages: list[Message]
|
||||
desc: str | None = None
|
||||
58
mcp_python/server/websocket.py
Normal file
58
mcp_python/server/websocket.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from starlette.types import Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
||||
"""
|
||||
WebSocket server transport for MCP. This is an ASGI application, suitable to be used with a framework like Starlette and a server like Hypercorn.
|
||||
"""
|
||||
|
||||
websocket = WebSocket(scope, receive, send)
|
||||
await websocket.accept(subprotocol="mcp")
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async def ws_reader():
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for message in websocket.iter_json():
|
||||
try:
|
||||
client_message = JSONRPCMessage.model_validate(message)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(client_message)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
async def ws_writer():
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
obj = message.model_dump(by_alias=True, mode="json")
|
||||
await websocket.send_json(obj)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(ws_reader)
|
||||
tg.start_soon(ws_writer)
|
||||
yield (read_stream, write_stream)
|
||||
0
mcp_python/shared/__init__.py
Normal file
0
mcp_python/shared/__init__.py
Normal file
14
mcp_python/shared/context.py
Normal file
14
mcp_python/shared/context.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from mcp_python.shared.session import BaseSession
|
||||
from mcp_python.types import RequestId, RequestParams
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
9
mcp_python/shared/exceptions.py
Normal file
9
mcp_python/shared/exceptions.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from mcp_python.types import ErrorData
|
||||
|
||||
|
||||
class McpError(Exception):
|
||||
"""
|
||||
Exception type raised when an error arrives over an MCP connection.
|
||||
"""
|
||||
|
||||
error: ErrorData
|
||||
40
mcp_python/shared/progress.py
Normal file
40
mcp_python/shared/progress.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp_python.shared.context import RequestContext
|
||||
from mcp_python.shared.session import BaseSession
|
||||
from mcp_python.types import ProgressToken
|
||||
|
||||
|
||||
class Progress(BaseModel):
|
||||
progress: float
|
||||
total: float | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProgressContext:
|
||||
session: BaseSession
|
||||
progress_token: ProgressToken
|
||||
total: float | None
|
||||
current: float = field(default=0.0, init=False)
|
||||
|
||||
async def progress(self, amount: float) -> None:
|
||||
self.current += amount
|
||||
|
||||
await self.session.send_progress_notification(
|
||||
self.progress_token, self.current, total=self.total
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def progress(ctx: RequestContext, total: float | None = None):
|
||||
if ctx.meta is None or ctx.meta.progressToken is None:
|
||||
raise ValueError("No progress token provided")
|
||||
|
||||
progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
|
||||
try:
|
||||
yield progress_ctx
|
||||
finally:
|
||||
pass
|
||||
244
mcp_python/shared/session.py
Normal file
244
mcp_python/shared/session.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp_python.shared.exceptions import McpError
|
||||
from mcp_python.types import (
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestParams,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
)
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar(
|
||||
"ReceiveNotificationT", ClientNotification, ServerNotification
|
||||
)
|
||||
|
||||
RequestId = str | int
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
def __init__(
|
||||
self,
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: "BaseSession",
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self._responded = False
|
||||
|
||||
async def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
assert not self._responded, "Request already responded to"
|
||||
self._responded = True
|
||||
|
||||
await self._session._send_response(
|
||||
request_id=self.request_id, response=response
|
||||
)
|
||||
|
||||
|
||||
class BaseSession(
|
||||
AbstractAsyncContextManager,
|
||||
Generic[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT,
|
||||
],
|
||||
):
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features like request/response linking, notifications, and progress.
|
||||
|
||||
This class is an async context manager that automatically starts processing messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[
|
||||
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
|
||||
]
|
||||
_request_id: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
) -> None:
|
||||
self._read_stream = read_stream
|
||||
self._write_stream = write_stream
|
||||
self._response_streams = {}
|
||||
self._request_id = 0
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[
|
||||
RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception
|
||||
]()
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
self._task_group = anyio.create_task_group()
|
||||
await self._task_group.__aenter__()
|
||||
self._task_group.start_soon(self._receive_loop)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Using BaseSession as a context manager should not block on exit (this would be very surprising behavior), so make sure to cancel the tasks in the task group.
|
||||
self._task_group.cancel_scope.cancel()
|
||||
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def send_request(
|
||||
self,
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the response contains an error.
|
||||
|
||||
Do not use this method to emit notifications! Use send_notification() instead.
|
||||
"""
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_stream, response_stream_reader = anyio.create_memory_object_stream[
|
||||
JSONRPCResponse | JSONRPCError
|
||||
](1)
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0", id=request_id, **request.model_dump(by_alias=True, mode="json")
|
||||
)
|
||||
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
if isinstance(response_or_error, JSONRPCError):
|
||||
raise McpError(response_or_error.error)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
async def send_notification(self, notification: SendNotificationT) -> None:
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect a response.
|
||||
"""
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json")
|
||||
)
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
|
||||
|
||||
async def _send_response(
|
||||
self, request_id: RequestId, response: SendResultT | ErrorData
|
||||
) -> None:
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
|
||||
else:
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(by_alias=True, mode="json"),
|
||||
)
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
async with (
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
self._incoming_message_stream_writer,
|
||||
):
|
||||
async for message in self._read_stream:
|
||||
if isinstance(message, Exception):
|
||||
await self._incoming_message_stream_writer.send(message)
|
||||
elif isinstance(message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.root.model_dump(by_alias=True, mode="json")
|
||||
)
|
||||
responder = RequestResponder(
|
||||
request_id=message.root.id,
|
||||
request_meta=validated_request.root.params._meta
|
||||
if validated_request.root.params
|
||||
else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
)
|
||||
|
||||
await self._received_request(responder)
|
||||
if not responder._responded:
|
||||
await self._incoming_message_stream_writer.send(responder)
|
||||
elif isinstance(message.root, JSONRPCNotification):
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.root.model_dump(by_alias=True, mode="json")
|
||||
)
|
||||
|
||||
await self._received_notification(notification)
|
||||
await self._incoming_message_stream_writer.send(notification)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.root.id, None)
|
||||
if stream:
|
||||
await stream.send(message.root)
|
||||
else:
|
||||
await self._incoming_message_stream_writer.send(
|
||||
RuntimeError(
|
||||
f"Received response with an unknown request ID: {message}"
|
||||
)
|
||||
)
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
|
||||
) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to listen on the message stream.
|
||||
|
||||
If the request is responded to within this method, it will not be forwarded on to the message stream.
|
||||
"""
|
||||
|
||||
async def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing to listen on the message stream.
|
||||
"""
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being processed.
|
||||
"""
|
||||
|
||||
@property
|
||||
def incoming_messages(
|
||||
self,
|
||||
) -> MemoryObjectReceiveStream[
|
||||
RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception
|
||||
]:
|
||||
return self._incoming_message_stream_reader
|
||||
1
mcp_python/shared/version.py
Normal file
1
mcp_python/shared/version.py
Normal file
@@ -0,0 +1 @@
|
||||
SUPPORTED_PROTOCOL_VERSION = 1
|
||||
709
mcp_python/types.py
Normal file
709
mcp_python/types.py
Normal file
@@ -0,0 +1,709 @@
|
||||
from typing import Any, Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, RootModel
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
"""
|
||||
Model Context Protocol bindings for Python
|
||||
|
||||
These bindings were generated from https://github.com/anthropic-experimental/mcp-spec, using Claude, with a prompt something like the following:
|
||||
|
||||
Generate idiomatic Python bindings for this schema for MCP, or the "Model Context Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version for reference.
|
||||
|
||||
* For the bindings, let's use Pydantic V2 models.
|
||||
* Each model should allow extra fields everywhere, by specifying `model_config = ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class.
|
||||
* Union types should be represented with a Pydantic `RootModel`.
|
||||
* Define additional model classes instead of using dictionaries. Do this even if they're not separate types in the schema.
|
||||
"""
|
||||
|
||||
|
||||
ProgressToken = str | int
|
||||
|
||||
|
||||
class RequestParams(BaseModel):
|
||||
class Meta(BaseModel):
|
||||
progressToken: ProgressToken | None = None
|
||||
"""
|
||||
If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
_meta: Meta | None = None
|
||||
|
||||
|
||||
class NotificationParams(BaseModel):
|
||||
class Meta(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
_meta: Meta | None = None
|
||||
"""
|
||||
This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications.
|
||||
"""
|
||||
|
||||
|
||||
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams)
|
||||
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams)
|
||||
|
||||
|
||||
class Request(BaseModel, Generic[RequestParamsT]):
|
||||
"""Base class for JSON-RPC requests."""
|
||||
|
||||
method: str
|
||||
params: RequestParamsT
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Notification(BaseModel, Generic[NotificationParamsT]):
|
||||
"""Base class for JSON-RPC notifications."""
|
||||
|
||||
method: str
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
"""Base class for JSON-RPC results."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
_meta: dict[str, Any] | None = None
|
||||
"""
|
||||
This result property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses.
|
||||
"""
|
||||
|
||||
|
||||
RequestId = str | int
|
||||
|
||||
|
||||
class JSONRPCRequest(Request):
|
||||
"""A request that expects a response."""
|
||||
|
||||
jsonrpc: Literal["2.0"]
|
||||
id: RequestId
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class JSONRPCNotification(Notification):
|
||||
"""A notification which does not expect a response."""
|
||||
|
||||
jsonrpc: Literal["2.0"]
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class JSONRPCResponse(BaseModel):
|
||||
"""A successful (non-error) response to a request."""
|
||||
|
||||
jsonrpc: Literal["2.0"]
|
||||
id: RequestId
|
||||
result: dict[str, Any]
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
# Standard JSON-RPC error codes
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
|
||||
|
||||
class ErrorData(BaseModel):
|
||||
"""Error information for JSON-RPC error responses."""
|
||||
|
||||
code: int
|
||||
"""The error type that occurred."""
|
||||
message: str
|
||||
"""A short description of the error. The message SHOULD be limited to a concise single sentence."""
|
||||
data: Any | None = None
|
||||
"""Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.)."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class JSONRPCError(BaseModel):
|
||||
"""A response to a request that indicates an error occurred."""
|
||||
|
||||
jsonrpc: Literal["2.0"]
|
||||
id: str | int
|
||||
error: ErrorData
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class JSONRPCMessage(
|
||||
RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class EmptyResult(Result):
|
||||
"""A response that indicates success but carries no data."""
|
||||
|
||||
|
||||
class Implementation(BaseModel):
|
||||
"""Describes the name and version of an MCP implementation."""
|
||||
|
||||
name: str
|
||||
version: str
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ClientCapabilities(BaseModel):
|
||||
"""Capabilities a client may support."""
|
||||
|
||||
experimental: dict[str, dict[str, Any]] | None = None
|
||||
"""Experimental, non-standard capabilities that the client supports."""
|
||||
sampling: dict[str, Any] | None = None
|
||||
"""Present if the client supports sampling from an LLM."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ServerCapabilities(BaseModel):
|
||||
"""Capabilities that a server may support."""
|
||||
|
||||
experimental: dict[str, dict[str, Any]] | None = None
|
||||
"""Experimental, non-standard capabilities that the server supports."""
|
||||
logging: dict[str, Any] | None = None
|
||||
"""Present if the server supports sending log messages to the client."""
|
||||
prompts: dict[str, Any] | None = None
|
||||
"""Present if the server offers any prompt templates."""
|
||||
resources: dict[str, Any] | None = None
|
||||
"""Present if the server offers any resources to read."""
|
||||
tools: dict[str, Any] | None = None
|
||||
"""Present if the server offers any tools to call."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class InitializeRequestParams(RequestParams):
|
||||
"""Parameters for the initialize request."""
|
||||
|
||||
protocolVersion: Literal[1]
|
||||
"""The latest version of the Model Context Protocol that the client supports."""
|
||||
capabilities: ClientCapabilities
|
||||
clientInfo: Implementation
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class InitializeRequest(Request):
|
||||
"""This request is sent from the client to the server when it first connects, asking it to begin initialization."""
|
||||
|
||||
method: Literal["initialize"]
|
||||
params: InitializeRequestParams
|
||||
|
||||
|
||||
class InitializeResult(Result):
|
||||
"""After receiving an initialize request from the client, the server sends this response."""
|
||||
|
||||
protocolVersion: Literal[1]
|
||||
"""The version of the Model Context Protocol that the server wants to use."""
|
||||
capabilities: ServerCapabilities
|
||||
serverInfo: Implementation
|
||||
|
||||
|
||||
class InitializedNotification(Notification):
|
||||
"""This notification is sent from the client to the server after initialization has finished."""
|
||||
|
||||
method: Literal["notifications/initialized"]
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
class PingRequest(Request):
|
||||
"""A ping, issued by either the server or the client, to check that the other party is still alive."""
|
||||
|
||||
method: Literal["ping"]
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
class ProgressNotificationParams(NotificationParams):
|
||||
"""Parameters for progress notifications."""
|
||||
|
||||
progressToken: ProgressToken
|
||||
"""The progress token which was given in the initial request, used to associate this notification with the request that is proceeding."""
|
||||
progress: float
|
||||
"""The progress thus far. This should increase every time progress is made, even if the total is unknown."""
|
||||
total: float | None = None
|
||||
"""Total number of items to process (or total progress required), if known."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ProgressNotification(Notification):
|
||||
"""An out-of-band notification used to inform the receiver of a progress update for a long-running request."""
|
||||
|
||||
method: Literal["notifications/progress"]
|
||||
params: ProgressNotificationParams
|
||||
|
||||
|
||||
class ListResourcesRequest(Request):
|
||||
"""Sent from the client to request a list of resources the server has."""
|
||||
|
||||
method: Literal["resources/list"]
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
"""A known resource that the server is capable of reading."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of this resource."""
|
||||
mimeType: str | None = None
|
||||
"""The MIME type of this resource, if known."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ResourceTemplate(BaseModel):
|
||||
"""A template description for resources available on the server."""
|
||||
|
||||
uriTemplate: str
|
||||
"""A URI template (according to RFC 6570) that can be used to construct resource URIs."""
|
||||
name: str | None = None
|
||||
"""A human-readable name for the type of resource this template refers to."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of what this template is for."""
|
||||
mimeType: str | None = None
|
||||
"""The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ListResourcesResult(Result):
|
||||
"""The server's response to a resources/list request from the client."""
|
||||
|
||||
resourceTemplates: list[ResourceTemplate] | None = None
|
||||
resources: list[Resource] | None = None
|
||||
|
||||
|
||||
class ReadResourceRequestParams(RequestParams):
|
||||
"""Parameters for reading a resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ReadResourceRequest(Request):
|
||||
"""Sent from the client to the server, to read a specific resource URI."""
|
||||
|
||||
method: Literal["resources/read"]
|
||||
params: ReadResourceRequestParams
|
||||
|
||||
|
||||
class ResourceContents(BaseModel):
|
||||
"""The contents of a specific resource or sub-resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of this resource."""
|
||||
mimeType: str | None = None
|
||||
"""The MIME type of this resource, if known."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class TextResourceContents(ResourceContents):
|
||||
"""Text contents of a resource."""
|
||||
|
||||
text: str
|
||||
"""The text of the item. This must only be set if the item can actually be represented as text (not binary data)."""
|
||||
|
||||
|
||||
class BlobResourceContents(ResourceContents):
|
||||
"""Binary contents of a resource."""
|
||||
|
||||
blob: str
|
||||
"""A base64-encoded string representing the binary data of the item."""
|
||||
|
||||
|
||||
class ReadResourceResult(Result):
|
||||
"""The server's response to a resources/read request from the client."""
|
||||
|
||||
contents: list[TextResourceContents | BlobResourceContents]
|
||||
|
||||
|
||||
class ResourceListChangedNotification(Notification):
|
||||
"""An optional notification from the server to the client, informing it that the list of resources it can read from has changed."""
|
||||
|
||||
method: Literal["notifications/resources/list_changed"]
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
class SubscribeRequestParams(RequestParams):
|
||||
"""Parameters for subscribing to a resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class SubscribeRequest(Request):
|
||||
"""Sent from the client to request resources/updated notifications from the server whenever a particular resource changes."""
|
||||
|
||||
method: Literal["resources/subscribe"]
|
||||
params: SubscribeRequestParams
|
||||
|
||||
|
||||
class UnsubscribeRequestParams(RequestParams):
|
||||
"""Parameters for unsubscribing from a resource."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of the resource to unsubscribe from."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class UnsubscribeRequest(Request):
|
||||
"""Sent from the client to request cancellation of resources/updated notifications from the server."""
|
||||
|
||||
method: Literal["resources/unsubscribe"]
|
||||
params: UnsubscribeRequestParams
|
||||
|
||||
|
||||
class ResourceUpdatedNotificationParams(NotificationParams):
|
||||
"""Parameters for resource update notifications."""
|
||||
|
||||
uri: AnyUrl
|
||||
"""The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ResourceUpdatedNotification(Notification):
|
||||
"""A notification from the server to the client, informing it that a resource has changed and may need to be read again."""
|
||||
|
||||
method: Literal["notifications/resources/updated"]
|
||||
params: ResourceUpdatedNotificationParams
|
||||
|
||||
|
||||
class ListPromptsRequest(Request):
|
||||
"""Sent from the client to request a list of prompts and prompt templates the server has."""
|
||||
|
||||
method: Literal["prompts/list"]
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
class PromptArgument(BaseModel):
|
||||
"""An argument for a prompt template."""
|
||||
|
||||
name: str
|
||||
"""The name of the argument."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of the argument."""
|
||||
required: bool | None = None
|
||||
"""Whether this argument must be provided."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
"""A prompt or prompt template that the server offers."""
|
||||
|
||||
name: str
|
||||
"""The name of the prompt or prompt template."""
|
||||
description: str | None = None
|
||||
"""An optional description of what this prompt provides."""
|
||||
arguments: list[PromptArgument] | None = None
|
||||
"""A list of arguments to use for templating the prompt."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ListPromptsResult(Result):
|
||||
"""The server's response to a prompts/list request from the client."""
|
||||
|
||||
prompts: list[Prompt]
|
||||
|
||||
|
||||
class GetPromptRequestParams(RequestParams):
|
||||
"""Parameters for getting a prompt."""
|
||||
|
||||
name: str
|
||||
"""The name of the prompt or prompt template."""
|
||||
arguments: dict[str, str] | None = None
|
||||
"""Arguments to use for templating the prompt."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class GetPromptRequest(Request):
|
||||
"""Used by the client to get a prompt provided by the server."""
|
||||
|
||||
method: Literal["prompts/get"]
|
||||
params: GetPromptRequestParams
|
||||
|
||||
|
||||
class TextContent(BaseModel):
|
||||
"""Text content for a message."""
|
||||
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
"""The text content of the message."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ImageContent(BaseModel):
|
||||
"""Image content for a message."""
|
||||
|
||||
type: Literal["image"]
|
||||
data: str
|
||||
"""The base64-encoded image data."""
|
||||
mimeType: str
|
||||
"""The MIME type of the image. Different providers may support different image types."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
Role = Literal["user", "assistant"]
|
||||
|
||||
|
||||
class SamplingMessage(BaseModel):
|
||||
"""Describes a message issued to or received from an LLM API."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class GetPromptResult(Result):
|
||||
"""The server's response to a prompts/get request from the client."""
|
||||
|
||||
description: str | None = None
|
||||
"""An optional description for the prompt."""
|
||||
messages: list[SamplingMessage]
|
||||
|
||||
|
||||
class ListToolsRequest(Request):
|
||||
"""Sent from the client to request a list of tools the server has."""
|
||||
|
||||
method: Literal["tools/list"]
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
"""Definition for a tool the client can call."""
|
||||
|
||||
name: str
|
||||
"""The name of the tool."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of the tool."""
|
||||
inputSchema: dict[str, Any]
|
||||
"""A JSON Schema object defining the expected parameters for the tool."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ListToolsResult(Result):
|
||||
"""The server's response to a tools/list request from the client."""
|
||||
|
||||
tools: list[Tool]
|
||||
|
||||
|
||||
class CallToolRequestParams(RequestParams):
|
||||
"""Parameters for calling a tool."""
|
||||
|
||||
name: str
|
||||
arguments: dict[str, Any] | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CallToolRequest(Request):
|
||||
"""Used by the client to invoke a tool provided by the server."""
|
||||
|
||||
method: Literal["tools/call"]
|
||||
params: CallToolRequestParams
|
||||
|
||||
|
||||
class CallToolResult(Result):
|
||||
"""The server's response to a tool call."""
|
||||
|
||||
toolResult: Any
|
||||
|
||||
|
||||
class ToolListChangedNotification(Notification):
|
||||
"""An optional notification from the server to the client, informing it that the list of tools it offers has changed."""
|
||||
|
||||
method: Literal["notifications/tools/list_changed"]
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
LoggingLevel = Literal["debug", "info", "warning", "error"]
|
||||
|
||||
|
||||
class SetLevelRequestParams(RequestParams):
|
||||
"""Parameters for setting the logging level."""
|
||||
|
||||
level: LoggingLevel
|
||||
"""The level of logging that the client wants to receive from the server."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class SetLevelRequest(Request):
|
||||
"""A request from the client to the server, to enable or adjust logging."""
|
||||
|
||||
method: Literal["logging/setLevel"]
|
||||
params: SetLevelRequestParams
|
||||
|
||||
|
||||
class LoggingMessageNotificationParams(NotificationParams):
|
||||
"""Parameters for logging message notifications."""
|
||||
|
||||
level: LoggingLevel
|
||||
"""The severity of this log message."""
|
||||
logger: str | None = None
|
||||
"""An optional name of the logger issuing this message."""
|
||||
data: Any
|
||||
"""The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class LoggingMessageNotification(Notification):
|
||||
"""Notification of a log message passed from server to client."""
|
||||
|
||||
method: Literal["notifications/message"]
|
||||
params: LoggingMessageNotificationParams
|
||||
|
||||
|
||||
IncludeContext = Literal["none", "thisServer", "allServers"]
|
||||
|
||||
|
||||
class CreateMessageRequestParams(RequestParams):
|
||||
"""Parameters for creating a message."""
|
||||
|
||||
messages: list[SamplingMessage]
|
||||
systemPrompt: str | None = None
|
||||
"""An optional system prompt the server wants to use for sampling."""
|
||||
includeContext: IncludeContext | None = None
|
||||
"""A request to include context from one or more MCP servers (including the caller), to be attached to the prompt."""
|
||||
temperature: float | None = None
|
||||
maxTokens: int
|
||||
"""The maximum number of tokens to sample, as requested by the server."""
|
||||
stopSequences: list[str] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
"""Optional metadata to pass through to the LLM provider."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CreateMessageRequest(Request):
|
||||
"""A request from the server to sample an LLM via the client."""
|
||||
|
||||
method: Literal["sampling/createMessage"]
|
||||
params: CreateMessageRequestParams
|
||||
|
||||
|
||||
StopReason = Literal["endTurn", "stopSequence", "maxTokens"]
|
||||
|
||||
|
||||
class CreateMessageResult(Result):
|
||||
"""The client's response to a sampling/create_message request from the server."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent
|
||||
model: str
|
||||
"""The name of the model that generated the message."""
|
||||
stopReason: StopReason
|
||||
"""The reason why sampling stopped."""
|
||||
|
||||
|
||||
class ResourceReference(BaseModel):
|
||||
"""A reference to a resource or resource template definition."""
|
||||
|
||||
type: Literal["ref/resource"]
|
||||
uri: str
|
||||
"""The URI or URI template of the resource."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PromptReference(BaseModel):
|
||||
"""Identifies a prompt."""
|
||||
|
||||
type: Literal["ref/prompt"]
|
||||
name: str
|
||||
"""The name of the prompt or prompt template"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompletionArgument(BaseModel):
|
||||
"""The argument's information for completion requests."""
|
||||
|
||||
name: str
|
||||
"""The name of the argument"""
|
||||
value: str
|
||||
"""The value of the argument to use for completion matching."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompleteRequestParams(RequestParams):
|
||||
"""Parameters for completion requests."""
|
||||
|
||||
ref: ResourceReference | PromptReference
|
||||
argument: CompletionArgument
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompleteRequest(Request):
|
||||
"""A request from the client to the server, to ask for completion options."""
|
||||
|
||||
method: Literal["completion/complete"]
|
||||
params: CompleteRequestParams
|
||||
|
||||
|
||||
class Completion(BaseModel):
|
||||
"""Completion information."""
|
||||
|
||||
values: list[str]
|
||||
"""An array of completion values. Must not exceed 100 items."""
|
||||
total: int | None = None
|
||||
"""The total number of completion options available. This can exceed the number of values actually sent in the response."""
|
||||
hasMore: bool | None = None
|
||||
"""Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompleteResult(Result):
|
||||
"""The server's response to a completion/complete request"""
|
||||
|
||||
completion: Completion
|
||||
|
||||
|
||||
class ClientRequest(
|
||||
RootModel[
|
||||
PingRequest
|
||||
| InitializeRequest
|
||||
| CompleteRequest
|
||||
| SetLevelRequest
|
||||
| GetPromptRequest
|
||||
| ListPromptsRequest
|
||||
| ListResourcesRequest
|
||||
| ReadResourceRequest
|
||||
| SubscribeRequest
|
||||
| UnsubscribeRequest
|
||||
| CallToolRequest
|
||||
| ListToolsRequest
|
||||
]
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class ClientNotification(RootModel[ProgressNotification | InitializedNotification]):
|
||||
pass
|
||||
|
||||
|
||||
class ClientResult(RootModel[EmptyResult | CreateMessageResult]):
|
||||
pass
|
||||
|
||||
|
||||
class ServerRequest(RootModel[PingRequest | CreateMessageRequest]):
|
||||
pass
|
||||
|
||||
|
||||
class ServerNotification(
|
||||
RootModel[
|
||||
ProgressNotification
|
||||
| LoggingMessageNotification
|
||||
| ResourceUpdatedNotification
|
||||
| ResourceListChangedNotification
|
||||
| ToolListChangedNotification
|
||||
]
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class ServerResult(
|
||||
RootModel[
|
||||
EmptyResult
|
||||
| InitializeResult
|
||||
| CompleteResult
|
||||
| GetPromptResult
|
||||
| ListPromptsResult
|
||||
| ListResourcesResult
|
||||
| ReadResourceResult
|
||||
| CallToolResult
|
||||
| ListToolsResult
|
||||
]
|
||||
):
|
||||
pass
|
||||
34
pyproject.toml
Normal file
34
pyproject.toml
Normal file
@@ -0,0 +1,34 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "mcp-python"
|
||||
version = "0.1.2"
|
||||
description = "Model Context Protocol implementation for Python"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"anyio",
|
||||
"httpx",
|
||||
"httpx-sse",
|
||||
"pydantic>=2.0.0",
|
||||
"starlette",
|
||||
"sse-starlette",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["mcp_python"]
|
||||
|
||||
[tool.pyright]
|
||||
include = ["mcp_python", "tests"]
|
||||
typeCheckingMode = "strict"
|
||||
|
||||
[tool.ruff]
|
||||
select = ["E", "F", "I"]
|
||||
ignore = []
|
||||
line-length = 88
|
||||
target-version = "py38"
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"__init__.py" = ["F401"]
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/client/__init__.py
Normal file
0
tests/client/__init__.py
Normal file
93
tests/client/test_session.py
Normal file
93
tests/client/test_session.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from mcp_python.client.session import ClientSession
|
||||
from mcp_python.types import (
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
Implementation,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ServerCapabilities,
|
||||
ServerResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_session_initialize():
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage
|
||||
](1)
|
||||
|
||||
initialized_notification = None
|
||||
|
||||
async def mock_server():
|
||||
nonlocal initialized_notification
|
||||
|
||||
jsonrpc_request = await client_to_server_receive.receive()
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.model_dump(by_alias=True, mode="json")
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=1,
|
||||
capabilities=ServerCapabilities(
|
||||
logging=None,
|
||||
resources=None,
|
||||
tools=None,
|
||||
experimental=None,
|
||||
prompts=None,
|
||||
),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
async with server_to_client_send:
|
||||
await server_to_client_send.send(
|
||||
JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json"),
|
||||
)
|
||||
)
|
||||
)
|
||||
jsonrpc_notification = await client_to_server_receive.receive()
|
||||
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
||||
initialized_notification = ClientNotification.model_validate(
|
||||
jsonrpc_notification.model_dump(by_alias=True, mode="json")
|
||||
)
|
||||
|
||||
async def listen_session():
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
async with (
|
||||
ClientSession(server_to_client_receive, client_to_server_send) as session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
tg.start_soon(mock_server)
|
||||
tg.start_soon(listen_session)
|
||||
result = await session.initialize()
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.protocolVersion == 1
|
||||
assert isinstance(result.capabilities, ServerCapabilities)
|
||||
assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")
|
||||
|
||||
# Check that the client sent the initialized notification
|
||||
assert initialized_notification
|
||||
assert isinstance(initialized_notification.root, InitializedNotification)
|
||||
38
tests/client/test_stdio.py
Normal file
38
tests/client/test_stdio.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import pytest
|
||||
|
||||
from mcp_python.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp_python.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_stdio_client():
|
||||
server_parameters = StdioServerParameters(command="/usr/bin/tee")
|
||||
|
||||
async with stdio_client(server_parameters) as (read_stream, write_stream):
|
||||
# Test sending and receiving messages
|
||||
messages = [
|
||||
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")),
|
||||
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})),
|
||||
]
|
||||
|
||||
async with write_stream:
|
||||
for message in messages:
|
||||
await write_stream.send(message)
|
||||
|
||||
read_messages = []
|
||||
async with read_stream:
|
||||
async for message in read_stream:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
read_messages.append(message)
|
||||
if len(read_messages) == 2:
|
||||
break
|
||||
|
||||
assert len(read_messages) == 2
|
||||
assert read_messages[0] == JSONRPCMessage(
|
||||
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
|
||||
)
|
||||
assert read_messages[1] == JSONRPCMessage(
|
||||
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
|
||||
)
|
||||
0
tests/server/__init__.py
Normal file
0
tests/server/__init__.py
Normal file
59
tests/server/test_session.py
Normal file
59
tests/server/test_session.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from mcp_python.client.session import ClientSession
|
||||
from mcp_python.server.session import ServerSession
|
||||
from mcp_python.types import (
|
||||
ClientNotification,
|
||||
InitializedNotification,
|
||||
JSONRPCMessage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_server_session_initialize():
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream(
|
||||
1, JSONRPCMessage
|
||||
)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream(
|
||||
1, JSONRPCMessage
|
||||
)
|
||||
|
||||
async def run_client(client: ClientSession):
|
||||
async for message in client_session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
received_initialized = False
|
||||
|
||||
async def run_server():
|
||||
nonlocal received_initialized
|
||||
|
||||
async with ServerSession(
|
||||
client_to_server_receive, server_to_client_send
|
||||
) as server_session:
|
||||
async for message in server_session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
if isinstance(message, ClientNotification) and isinstance(
|
||||
message.root, InitializedNotification
|
||||
):
|
||||
received_initialized = True
|
||||
return
|
||||
|
||||
try:
|
||||
async with (
|
||||
ClientSession(
|
||||
server_to_client_receive, client_to_server_send
|
||||
) as client_session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
tg.start_soon(run_client, client_session)
|
||||
tg.start_soon(run_server)
|
||||
|
||||
await client_session.initialize()
|
||||
except* anyio.ClosedResourceError:
|
||||
pass
|
||||
|
||||
assert received_initialized
|
||||
68
tests/server/test_stdio.py
Normal file
68
tests/server/test_stdio.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import io
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from mcp_python.server.stdio import stdio_server
|
||||
from mcp_python.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_stdio_server():
|
||||
stdin = io.StringIO()
|
||||
stdout = io.StringIO()
|
||||
|
||||
messages = [
|
||||
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")),
|
||||
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})),
|
||||
]
|
||||
|
||||
for message in messages:
|
||||
stdin.write(message.model_dump_json() + "\n")
|
||||
stdin.seek(0)
|
||||
|
||||
async with stdio_server(
|
||||
stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)
|
||||
) as (read_stream, write_stream):
|
||||
received_messages = []
|
||||
async with read_stream:
|
||||
async for message in read_stream:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
received_messages.append(message)
|
||||
if len(received_messages) == 2:
|
||||
break
|
||||
|
||||
# Verify received messages
|
||||
assert len(received_messages) == 2
|
||||
assert received_messages[0] == JSONRPCMessage(
|
||||
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
|
||||
)
|
||||
assert received_messages[1] == JSONRPCMessage(
|
||||
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
|
||||
)
|
||||
|
||||
# Test sending responses from the server
|
||||
responses = [
|
||||
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")),
|
||||
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})),
|
||||
]
|
||||
|
||||
async with write_stream:
|
||||
for response in responses:
|
||||
await write_stream.send(response)
|
||||
|
||||
stdout.seek(0)
|
||||
output_lines = stdout.readlines()
|
||||
assert len(output_lines) == 2
|
||||
|
||||
received_responses = [
|
||||
JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines
|
||||
]
|
||||
assert len(received_responses) == 2
|
||||
assert received_responses[0] == JSONRPCMessage(
|
||||
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
|
||||
)
|
||||
assert received_responses[1] == JSONRPCMessage(
|
||||
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
|
||||
)
|
||||
24
tests/test_types.py
Normal file
24
tests/test_types.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from mcp_python.types import ClientRequest, JSONRPCMessage, JSONRPCRequest
|
||||
|
||||
|
||||
def test_jsonrpc_request():
|
||||
json_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": 1,
|
||||
"capabilities": {"batch": None, "sampling": None},
|
||||
"clientInfo": {"name": "mcp_python", "version": "0.1.0"},
|
||||
},
|
||||
}
|
||||
|
||||
request = JSONRPCMessage.model_validate(json_data)
|
||||
assert isinstance(request.root, JSONRPCRequest)
|
||||
ClientRequest.model_validate(request.model_dump(by_alias=True))
|
||||
|
||||
assert request.root.jsonrpc == "2.0"
|
||||
assert request.root.id == 1
|
||||
assert request.root.method == "initialize"
|
||||
assert request.root.params is not None
|
||||
assert request.root.params["protocolVersion"] == 1
|
||||
Reference in New Issue
Block a user