From 7ab1fc71aa9447bf63fa837e9fad049d00d9a488 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:01:43 -0500 Subject: [PATCH] attempt to get server to shut down --- tests/shared/test_sse.py | 44 +++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 96c9758..bab41a8 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,5 +1,6 @@ # test_sse.py import re +import socket import time import json import anyio @@ -25,13 +26,9 @@ SERVER_NAME = "test_server_for_SSE" @pytest.fixture def server_port() -> int: - import socket - - s = socket.socket() - s.bind(('', 0)) - port = s.getsockname()[1] - s.close() - return port + with socket.socket() as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] @pytest.fixture def server_url(server_port: int) -> str: @@ -89,6 +86,12 @@ async def server_app()-> Starlette: return app +@pytest.fixture(autouse=True) +def space_around_test(): + time.sleep(0.1) + yield + time.sleep(0.1) + @pytest.fixture() def server(server_app: Starlette, server_port: int): server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=server_port, log_level="error")) @@ -99,9 +102,27 @@ def server(server_app: Starlette, server_port: int): while not server.started: print('waiting for server to start') time.sleep(0.5) - yield - print('killing server') - server_thread.join(timeout=0.1) + + try: + yield + finally: + print('killing server') + # Signal the server to stop + server.should_exit = True + + # Force close the server's main socket + if hasattr(server.servers, "servers"): + for s in server.servers: + print(f'closing {s}') + s.close() + + # Wait for thread to finish + server_thread.join(timeout=2) + if server_thread.is_alive(): + print("Warning: Server thread did not exit cleanly") + # Optionally, you could add more aggressive cleanup here + import _thread + _thread.interrupt_main() @pytest.fixture() async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: @@ -167,3 +188,6 @@ async def test_sse_client_exception_handling(initialized_sse_client_session: Cli session = initialized_sse_client_session with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): await session.read_resource(uri=AnyUrl("xxx://will-not-work")) + + +# TODO: test that timeouts are respected and that the error comes back