mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Prevent stdio connection hang for missing server path. (#401)
Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
@@ -108,20 +108,28 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
|||||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
|
|
||||||
command = _get_executable_command(server.command)
|
try:
|
||||||
|
command = _get_executable_command(server.command)
|
||||||
|
|
||||||
# Open process with stderr piped for capture
|
# Open process with stderr piped for capture
|
||||||
process = await _create_platform_compatible_process(
|
process = await _create_platform_compatible_process(
|
||||||
command=command,
|
command=command,
|
||||||
args=server.args,
|
args=server.args,
|
||||||
env=(
|
env=(
|
||||||
{**get_default_environment(), **server.env}
|
{**get_default_environment(), **server.env}
|
||||||
if server.env is not None
|
if server.env is not None
|
||||||
else get_default_environment()
|
else get_default_environment()
|
||||||
),
|
),
|
||||||
errlog=errlog,
|
errlog=errlog,
|
||||||
cwd=server.cwd,
|
cwd=server.cwd,
|
||||||
)
|
)
|
||||||
|
except OSError:
|
||||||
|
# Clean up streams if process creation fails
|
||||||
|
await read_stream.aclose()
|
||||||
|
await write_stream.aclose()
|
||||||
|
await read_stream_writer.aclose()
|
||||||
|
await write_stream_reader.aclose()
|
||||||
|
raise
|
||||||
|
|
||||||
async def stdout_reader():
|
async def stdout_reader():
|
||||||
assert process.stdout, "Opened process is missing stdout"
|
assert process.stdout, "Opened process is missing stdout"
|
||||||
@@ -177,12 +185,18 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
|||||||
yield read_stream, write_stream
|
yield read_stream, write_stream
|
||||||
finally:
|
finally:
|
||||||
# Clean up process to prevent any dangling orphaned processes
|
# Clean up process to prevent any dangling orphaned processes
|
||||||
if sys.platform == "win32":
|
try:
|
||||||
await terminate_windows_process(process)
|
if sys.platform == "win32":
|
||||||
else:
|
await terminate_windows_process(process)
|
||||||
process.terminate()
|
else:
|
||||||
|
process.terminate()
|
||||||
|
except ProcessLookupError:
|
||||||
|
# Process already exited, which is fine
|
||||||
|
pass
|
||||||
await read_stream.aclose()
|
await read_stream.aclose()
|
||||||
await write_stream.aclose()
|
await write_stream.aclose()
|
||||||
|
await read_stream_writer.aclose()
|
||||||
|
await write_stream_reader.aclose()
|
||||||
|
|
||||||
|
|
||||||
def _get_executable_command(command: str) -> str:
|
def _get_executable_command(command: str) -> str:
|
||||||
|
|||||||
@@ -2,11 +2,17 @@ import shutil
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
from mcp.client.session import ClientSession
|
||||||
|
from mcp.client.stdio import (
|
||||||
|
StdioServerParameters,
|
||||||
|
stdio_client,
|
||||||
|
)
|
||||||
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.shared.message import SessionMessage
|
from mcp.shared.message import SessionMessage
|
||||||
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||||
|
|
||||||
tee: str = shutil.which("tee") # type: ignore
|
tee: str = shutil.which("tee") # type: ignore
|
||||||
|
python: str = shutil.which("python") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -50,3 +56,43 @@ async def test_stdio_client():
|
|||||||
assert read_messages[1] == JSONRPCMessage(
|
assert read_messages[1] == JSONRPCMessage(
|
||||||
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
|
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_stdio_client_bad_path():
|
||||||
|
"""Check that the connection doesn't hang if process errors."""
|
||||||
|
server_params = StdioServerParameters(
|
||||||
|
command="python", args=["-c", "non-existent-file.py"]
|
||||||
|
)
|
||||||
|
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
# The session should raise an error when the connection closes
|
||||||
|
with pytest.raises(McpError) as exc_info:
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
# Check that we got a connection closed error
|
||||||
|
assert exc_info.value.error.code == CONNECTION_CLOSED
|
||||||
|
assert "Connection closed" in exc_info.value.error.message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_stdio_client_nonexistent_command():
|
||||||
|
"""Test that stdio_client raises an error for non-existent commands."""
|
||||||
|
# Create a server with a non-existent command
|
||||||
|
server_params = StdioServerParameters(
|
||||||
|
command="/path/to/nonexistent/command",
|
||||||
|
args=["--help"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should raise an error when trying to start the process
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
async with stdio_client(server_params) as (_, _):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# The error should indicate the command was not found
|
||||||
|
error_message = str(exc_info.value)
|
||||||
|
assert (
|
||||||
|
"nonexistent" in error_message
|
||||||
|
or "not found" in error_message.lower()
|
||||||
|
or "cannot find the file" in error_message.lower() # Windows error message
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user