mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Initial import
This commit is contained in:
0
mcp_python/client/__init__.py
Normal file
0
mcp_python/client/__init__.py
Normal file
76
mcp_python/client/__main__.py
Normal file
76
mcp_python/client/__main__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
import click
|
||||
|
||||
from mcp_python.client.session import ClientSession
|
||||
from mcp_python.client.sse import sse_client
|
||||
from mcp_python.client.stdio import StdioServerParameters, stdio_client
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def receive_loop(session: ClientSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
|
||||
logger.info("Received message from server: %s", message)
|
||||
|
||||
|
||||
async def run_session(read_stream, write_stream):
|
||||
async with (
|
||||
ClientSession(read_stream, write_stream) as session,
|
||||
anyio.create_task_group() as tg,
|
||||
):
|
||||
tg.start_soon(receive_loop, session)
|
||||
|
||||
logger.info("Initializing session")
|
||||
await session.initialize()
|
||||
logger.info("Initialized")
|
||||
|
||||
|
||||
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
|
||||
env_dict = dict(env)
|
||||
|
||||
if urlparse(command_or_url).scheme in ("http", "https"):
|
||||
# Use SSE client for HTTP(S) URLs
|
||||
async with sse_client(command_or_url) as streams:
|
||||
await run_session(*streams)
|
||||
else:
|
||||
# Use stdio client for commands
|
||||
server_parameters = StdioServerParameters(
|
||||
command=command_or_url, args=args, env=env_dict
|
||||
)
|
||||
async with stdio_client(server_parameters) as streams:
|
||||
await run_session(*streams)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("command_or_url")
|
||||
@click.argument("args", nargs=-1)
|
||||
@click.option(
|
||||
"--env",
|
||||
"-e",
|
||||
multiple=True,
|
||||
nargs=2,
|
||||
metavar="KEY VALUE",
|
||||
help="Environment variables to set. Can be used multiple times.",
|
||||
)
|
||||
def cli(*args, **kwargs):
|
||||
anyio.run(partial(main, *args, **kwargs), backend="trio")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
211
mcp_python/client/session.py
Normal file
211
mcp_python/client/session.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.shared.session import BaseSession
|
||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
||||
from mcp_python.types import (
|
||||
CallToolResult,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
EmptyResult,
|
||||
Implementation,
|
||||
InitializedNotification,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
ListResourcesResult,
|
||||
LoggingLevel,
|
||||
ReadResourceResult,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
ClientRequest,
|
||||
ClientNotification,
|
||||
ClientResult,
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, ServerRequest, ServerNotification)
|
||||
|
||||
async def initialize(self) -> InitializeResult:
|
||||
from mcp_python.types import (
|
||||
InitializeRequest,
|
||||
InitializeRequestParams,
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
ClientRequest(
|
||||
InitializeRequest(
|
||||
method="initialize",
|
||||
params=InitializeRequestParams(
|
||||
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
|
||||
capabilities=ClientCapabilities(
|
||||
sampling=None, experimental=None
|
||||
),
|
||||
clientInfo=Implementation(name="mcp_python", version="0.1.0"),
|
||||
),
|
||||
)
|
||||
),
|
||||
InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION:
|
||||
raise RuntimeError(
|
||||
f"Unsupported protocol version from the server: {result.protocolVersion}"
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
InitializedNotification(method="notifications/initialized")
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def send_ping(self) -> EmptyResult:
|
||||
"""Send a ping request."""
|
||||
from mcp_python.types import PingRequest
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
from mcp_python.types import (
|
||||
ProgressNotification,
|
||||
ProgressNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ClientNotification(
|
||||
ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
from mcp_python.types import (
|
||||
SetLevelRequest,
|
||||
SetLevelRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def list_resources(self) -> ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
from mcp_python.types import (
|
||||
ListResourcesRequest,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
ListResourcesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
from mcp_python.types import (
|
||||
ReadResourceRequest,
|
||||
ReadResourceRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
ReadResourceResult,
|
||||
)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
from mcp_python.types import (
|
||||
SubscribeRequest,
|
||||
SubscribeRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
from mcp_python.types import (
|
||||
UnsubscribeRequest,
|
||||
UnsubscribeRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict | None = None
|
||||
) -> CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
from mcp_python.types import (
|
||||
CallToolRequest,
|
||||
CallToolRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ClientRequest(
|
||||
CallToolRequest(
|
||||
method="tools/call",
|
||||
params=CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
CallToolResult,
|
||||
)
|
||||
129
mcp_python/client/sse.py
Normal file
129
mcp_python/client/sse.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client(url: str, timeout: float = 5, sse_read_timeout: float = 60 * 5):
|
||||
"""
|
||||
Client transport for SSE.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
url,
|
||||
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
|
||||
async def sse_reader(
|
||||
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||
):
|
||||
try:
|
||||
async for sse in event_source.aiter_sse():
|
||||
logger.debug(f"Received SSE event: {sse.event}")
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
logger.info(
|
||||
f"Received endpoint URL: {endpoint_url}"
|
||||
)
|
||||
|
||||
url_parsed = urlparse(url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
if (
|
||||
url_parsed.netloc != endpoint_parsed.netloc
|
||||
or url_parsed.scheme
|
||||
!= endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
task_status.started(endpoint_url)
|
||||
|
||||
case "message":
|
||||
try:
|
||||
message = (
|
||||
JSONRPCMessage.model_validate_json(
|
||||
sse.data
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"Received server message: {message}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error parsing server message: {exc}"
|
||||
)
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in sse_reader: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
|
||||
async def post_writer(endpoint_url: str):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=message.model_dump(by_alias=True, mode="json"),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(
|
||||
f"Client message sent successfully: {response.status_code}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
finally:
|
||||
await write_stream.aclose()
|
||||
|
||||
endpoint_url = await tg.start(sse_reader)
|
||||
logger.info(
|
||||
f"Starting post writer with endpoint URL: {endpoint_url}"
|
||||
)
|
||||
tg.start_soon(post_writer, endpoint_url)
|
||||
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
84
mcp_python/client/stdio.py
Normal file
84
mcp_python/client/stdio.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
|
||||
class StdioServerParameters(BaseModel):
|
||||
command: str
|
||||
"""The executable to run to start the server."""
|
||||
|
||||
args: list[str] = Field(default_factory=list)
|
||||
"""Command line arguments to pass to the executable."""
|
||||
|
||||
env: dict[str, str] = Field(default_factory=dict)
|
||||
"""
|
||||
The environment to use when spawning the process.
|
||||
|
||||
The environment is NOT inherited from the parent process by default.
|
||||
"""
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_client(server: StdioServerParameters):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
process = await anyio.open_process(
|
||||
[server.command, *server.args], env=server.env, stderr=sys.stderr
|
||||
)
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
buffer = ""
|
||||
async for chunk in TextReceiveStream(process.stdout):
|
||||
lines = (buffer + chunk).split("\n")
|
||||
buffer = lines.pop()
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdin_writer():
|
||||
assert process.stdin, "Opened process is missing stdin"
|
||||
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True)
|
||||
await process.stdin.send((json + "\n").encode())
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with (
|
||||
anyio.create_task_group() as tg,
|
||||
process,
|
||||
):
|
||||
tg.start_soon(stdout_reader)
|
||||
tg.start_soon(stdin_writer)
|
||||
yield read_stream, write_stream
|
||||
Reference in New Issue
Block a user