Prevent stdio connection hang for missing server path. (#401)

Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
Tim Child
2025-05-28 14:57:46 -07:00
committed by GitHub
parent 70014a2bbb
commit f5dd324354
2 changed files with 79 additions and 19 deletions

View File

@@ -108,6 +108,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
try:
command = _get_executable_command(server.command) command = _get_executable_command(server.command)
# Open process with stderr piped for capture # Open process with stderr piped for capture
@@ -122,6 +123,13 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
errlog=errlog, errlog=errlog,
cwd=server.cwd, cwd=server.cwd,
) )
except OSError:
# Clean up streams if process creation fails
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
raise
async def stdout_reader(): async def stdout_reader():
assert process.stdout, "Opened process is missing stdout" assert process.stdout, "Opened process is missing stdout"
@@ -177,12 +185,18 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
yield read_stream, write_stream yield read_stream, write_stream
finally: finally:
# Clean up process to prevent any dangling orphaned processes # Clean up process to prevent any dangling orphaned processes
try:
if sys.platform == "win32": if sys.platform == "win32":
await terminate_windows_process(process) await terminate_windows_process(process)
else: else:
process.terminate() process.terminate()
except ProcessLookupError:
# Process already exited, which is fine
pass
await read_stream.aclose() await read_stream.aclose()
await write_stream.aclose() await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
def _get_executable_command(command: str) -> str: def _get_executable_command(command: str) -> str:

View File

@@ -2,11 +2,17 @@ import shutil
import pytest import pytest
from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.client.session import ClientSession
from mcp.client.stdio import (
StdioServerParameters,
stdio_client,
)
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
tee: str = shutil.which("tee") # type: ignore tee: str = shutil.which("tee") # type: ignore
python: str = shutil.which("python") # type: ignore
@pytest.mark.anyio @pytest.mark.anyio
@@ -50,3 +56,43 @@ async def test_stdio_client():
assert read_messages[1] == JSONRPCMessage( assert read_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
) )
@pytest.mark.anyio
async def test_stdio_client_bad_path():
"""Check that the connection doesn't hang if process errors."""
server_params = StdioServerParameters(
command="python", args=["-c", "non-existent-file.py"]
)
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
# The session should raise an error when the connection closes
with pytest.raises(McpError) as exc_info:
await session.initialize()
# Check that we got a connection closed error
assert exc_info.value.error.code == CONNECTION_CLOSED
assert "Connection closed" in exc_info.value.error.message
@pytest.mark.anyio
async def test_stdio_client_nonexistent_command():
"""Test that stdio_client raises an error for non-existent commands."""
# Create a server with a non-existent command
server_params = StdioServerParameters(
command="/path/to/nonexistent/command",
args=["--help"],
)
# Should raise an error when trying to start the process
with pytest.raises(Exception) as exc_info:
async with stdio_client(server_params) as (_, _):
pass
# The error should indicate the command was not found
error_message = str(exc_info.value)
assert (
"nonexistent" in error_message
or "not found" in error_message.lower()
or "cannot find the file" in error_message.lower() # Windows error message
)