mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Prevent stdio connection hang for missing server path. (#401)
Co-authored-by: ihrpr <inna@anthropic.com>
This commit is contained in:
@@ -108,6 +108,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
try:
|
||||
command = _get_executable_command(server.command)
|
||||
|
||||
# Open process with stderr piped for capture
|
||||
@@ -122,6 +123,13 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
||||
errlog=errlog,
|
||||
cwd=server.cwd,
|
||||
)
|
||||
except OSError:
|
||||
# Clean up streams if process creation fails
|
||||
await read_stream.aclose()
|
||||
await write_stream.aclose()
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
raise
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
@@ -177,12 +185,18 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
# Clean up process to prevent any dangling orphaned processes
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
await terminate_windows_process(process)
|
||||
else:
|
||||
process.terminate()
|
||||
except ProcessLookupError:
|
||||
# Process already exited, which is fine
|
||||
pass
|
||||
await read_stream.aclose()
|
||||
await write_stream.aclose()
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
|
||||
|
||||
def _get_executable_command(command: str) -> str:
|
||||
|
||||
@@ -2,11 +2,17 @@ import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.stdio import (
|
||||
StdioServerParameters,
|
||||
stdio_client,
|
||||
)
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||
from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
|
||||
|
||||
tee: str = shutil.which("tee") # type: ignore
|
||||
python: str = shutil.which("python") # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -50,3 +56,43 @@ async def test_stdio_client():
|
||||
assert read_messages[1] == JSONRPCMessage(
|
||||
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_stdio_client_bad_path():
|
||||
"""Check that the connection doesn't hang if process errors."""
|
||||
server_params = StdioServerParameters(
|
||||
command="python", args=["-c", "non-existent-file.py"]
|
||||
)
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# The session should raise an error when the connection closes
|
||||
with pytest.raises(McpError) as exc_info:
|
||||
await session.initialize()
|
||||
|
||||
# Check that we got a connection closed error
|
||||
assert exc_info.value.error.code == CONNECTION_CLOSED
|
||||
assert "Connection closed" in exc_info.value.error.message
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_stdio_client_nonexistent_command():
|
||||
"""Test that stdio_client raises an error for non-existent commands."""
|
||||
# Create a server with a non-existent command
|
||||
server_params = StdioServerParameters(
|
||||
command="/path/to/nonexistent/command",
|
||||
args=["--help"],
|
||||
)
|
||||
|
||||
# Should raise an error when trying to start the process
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
async with stdio_client(server_params) as (_, _):
|
||||
pass
|
||||
|
||||
# The error should indicate the command was not found
|
||||
error_message = str(exc_info.value)
|
||||
assert (
|
||||
"nonexistent" in error_message
|
||||
or "not found" in error_message.lower()
|
||||
or "cannot find the file" in error_message.lower() # Windows error message
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user