mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
221 lines
6.4 KiB
Python
221 lines
6.4 KiB
Python
import os
|
|
import sys
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from typing import Literal, TextIO
|
|
|
|
import anyio
|
|
import anyio.lowlevel
|
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
from anyio.streams.text import TextReceiveStream
|
|
from pydantic import BaseModel, Field
|
|
|
|
import mcp.types as types
|
|
from mcp.shared.message import SessionMessage
|
|
|
|
from .win32 import (
|
|
create_windows_process,
|
|
get_windows_executable_command,
|
|
terminate_windows_process,
|
|
)
|
|
|
|
# Names of the environment variables that a spawned server inherits by
# default. Which list applies is decided once, at import time, from the
# host platform.
DEFAULT_INHERITED_ENV_VARS = (
    ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
    if sys.platform != "win32"
    else [
        "APPDATA",
        "HOMEDRIVE",
        "HOMEPATH",
        "LOCALAPPDATA",
        "PATH",
        "PROCESSOR_ARCHITECTURE",
        "SYSTEMDRIVE",
        "SYSTEMROOT",
        "TEMP",
        "USERNAME",
        "USERPROFILE",
    ]
)
|
|
|
|
|
|
def get_default_environment() -> dict[str, str]:
    """
    Build the baseline environment for a spawned server.

    Only the variables named in DEFAULT_INHERITED_ENV_VARS are considered.
    A variable is copied from the current process environment when it is set
    and its value does not start with "()".
    """
    candidates = (
        (name, os.environ.get(name)) for name in DEFAULT_INHERITED_ENV_VARS
    )
    return {
        name: value
        for name, value in candidates
        # Values starting with "()" are exported shell functions, which are a
        # security risk, so they are never passed on.
        if value is not None and not value.startswith("()")
    }
|
|
|
|
|
|
class StdioServerParameters(BaseModel):
    # Settings consumed by stdio_client when it spawns a server process and
    # exchanges messages with it.

    command: str
    """Program that is executed to launch the server."""

    args: list[str] = Field(default_factory=list)
    """Arguments placed on the command line after the executable."""

    env: dict[str, str] | None = None
    """
    Environment variables for the spawned process.

    When left as None, get_default_environment() supplies the environment.
    """

    cwd: str | Path | None = None
    """Directory the process is started in."""

    encoding: str = "utf-8"
    """
    Text encoding applied to messages exchanged with the server.

    The default is utf-8.
    """

    encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
    """
    How text encoding errors are handled.

    The accepted values are explained at
    https://docs.python.org/3/library/codecs.html#codec-base-classes
    """
|
|
|
|
|
|
@asynccontextmanager
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
    """
    Client transport for stdio: this will connect to a server by spawning a
    process and communicating with it over stdin/stdout.

    Args:
        server: Command, arguments, environment, working directory and text
            encoding settings for the server process.
        errlog: Text stream passed to the spawned process as its stderr.

    Yields:
        A ``(read_stream, write_stream)`` pair. ``read_stream`` delivers one
        ``SessionMessage`` per newline-terminated line of server output that
        validates as a JSON-RPC message, or the exception raised while
        validating a line that does not. Every ``SessionMessage`` sent to
        ``write_stream`` is written to the server's stdin as one line of JSON.

    Raises:
        OSError: Propagated from process creation; the memory streams are
            closed before it is re-raised.
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    try:
        command = _get_executable_command(server.command)

        # Spawn the server. Its stderr is handed to `errlog`; entries in
        # server.env override the default inherited environment.
        process = await _create_platform_compatible_process(
            command=command,
            args=server.args,
            env={**get_default_environment(), **(server.env or {})},
            errlog=errlog,
            cwd=server.cwd,
        )
    except OSError:
        # No process was started, so nothing will ever use these streams.
        # Close them instead of leaving them open until garbage collection.
        read_stream.close()
        write_stream.close()
        read_stream_writer.close()
        write_stream_reader.close()
        raise

    async def stdout_reader():
        """Decode server stdout, split it into lines, and forward each line."""
        assert process.stdout, "Opened process is missing stdout"

        try:
            async with read_stream_writer:
                buffer = ""
                async for chunk in TextReceiveStream(
                    process.stdout,
                    encoding=server.encoding,
                    errors=server.encoding_error_handler,
                ):
                    # A chunk may stop mid-line: the last element after the
                    # split is the unterminated tail, kept for the next chunk.
                    lines = (buffer + chunk).split("\n")
                    buffer = lines.pop()

                    for line in lines:
                        try:
                            message = types.JSONRPCMessage.model_validate_json(line)
                        except Exception as exc:
                            # Hand the failure to the consumer and keep reading
                            await read_stream_writer.send(exc)
                            continue

                        session_message = SessionMessage(message)
                        await read_stream_writer.send(session_message)
        except (anyio.ClosedResourceError, anyio.BrokenResourceError):
            # A stream was closed while in use; BrokenResourceError means the
            # consumer closed read_stream. Both are expected during shutdown.
            await anyio.lowlevel.checkpoint()

    async def stdin_writer():
        """Serialize outgoing messages and write them to server stdin."""
        assert process.stdin, "Opened process is missing stdin"

        try:
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    payload = session_message.message.model_dump_json(
                        by_alias=True, exclude_none=True
                    )
                    await process.stdin.send(
                        (payload + "\n").encode(
                            encoding=server.encoding,
                            errors=server.encoding_error_handler,
                        )
                    )
        except anyio.ClosedResourceError:
            await anyio.lowlevel.checkpoint()

    async with (
        anyio.create_task_group() as tg,
        process,
    ):
        tg.start_soon(stdout_reader)
        tg.start_soon(stdin_writer)
        try:
            yield read_stream, write_stream
        finally:
            # Clean up process to prevent any dangling orphaned processes
            try:
                if sys.platform == "win32":
                    await terminate_windows_process(process)
                else:
                    process.terminate()
            except ProcessLookupError:
                # The server already exited on its own; nothing to terminate
                pass
            finally:
                # Closing the caller-facing ends ends stdin_writer's loop and
                # unblocks a pending send in stdout_reader, so the task group
                # can finish even if the caller never closed the streams.
                read_stream.close()
                write_stream.close()
|
|
|
|
|
|
def _get_executable_command(command: str) -> str:
    """
    Resolve the command to execute on the current platform.

    Args:
        command: Base command (e.g., 'uvx', 'npx')

    Returns:
        str: The result of get_windows_executable_command(command) on win32;
        the command unchanged on every other platform.
    """
    if sys.platform != "win32":
        return command

    return get_windows_executable_command(command)
|
|
|
|
|
|
async def _create_platform_compatible_process(
    command: str,
    args: list[str],
    env: dict[str, str] | None = None,
    errlog: TextIO = sys.stderr,
    cwd: Path | str | None = None,
):
    """
    Start the server subprocess using the mechanism for the current platform.

    On win32 the work is delegated to create_windows_process; elsewhere the
    process is opened with anyio.open_process, with `errlog` as its stderr.

    Returns the process handle produced by whichever call was used.
    """
    if sys.platform == "win32":
        return await create_windows_process(command, args, env, errlog, cwd)

    argv = [command, *args]
    return await anyio.open_process(argv, env=env, stderr=errlog, cwd=cwd)
|