mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Properly infer prefix for SSE messages (#659)
This commit is contained in:
@@ -100,10 +100,26 @@ class SseServerTransport:
|
|||||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
|
|
||||||
session_id = uuid4()
|
session_id = uuid4()
|
||||||
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
|
|
||||||
self._read_stream_writers[session_id] = read_stream_writer
|
self._read_stream_writers[session_id] = read_stream_writer
|
||||||
logger.debug(f"Created new session with ID: {session_id}")
|
logger.debug(f"Created new session with ID: {session_id}")
|
||||||
|
|
||||||
|
# Determine the full path for the message endpoint to be sent to the client.
|
||||||
|
# scope['root_path'] is the prefix where the current Starlette app
|
||||||
|
# instance is mounted.
|
||||||
|
# e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix".
|
||||||
|
root_path = scope.get("root_path", "")
|
||||||
|
|
||||||
|
# self._endpoint is the path *within* this app, e.g., "/messages".
|
||||||
|
# Concatenating them gives the full absolute path from the server root.
|
||||||
|
# e.g., "" + "/messages" -> "/messages"
|
||||||
|
# e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages"
|
||||||
|
full_message_path_for_client = root_path.rstrip("/") + self._endpoint
|
||||||
|
|
||||||
|
# This is the URI (path + query) the client will use to POST messages.
|
||||||
|
client_post_uri_data = (
|
||||||
|
f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
|
||||||
|
)
|
||||||
|
|
||||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||||
dict[str, Any]
|
dict[str, Any]
|
||||||
](0)
|
](0)
|
||||||
@@ -111,8 +127,10 @@ class SseServerTransport:
|
|||||||
async def sse_writer():
|
async def sse_writer():
|
||||||
logger.debug("Starting SSE writer")
|
logger.debug("Starting SSE writer")
|
||||||
async with sse_stream_writer, write_stream_reader:
|
async with sse_stream_writer, write_stream_reader:
|
||||||
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
|
await sse_stream_writer.send(
|
||||||
logger.debug(f"Sent endpoint event: {session_uri}")
|
{"event": "endpoint", "data": client_post_uri_data}
|
||||||
|
)
|
||||||
|
logger.debug(f"Sent endpoint event: {client_post_uri_data}")
|
||||||
|
|
||||||
async for session_message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
logger.debug(f"Sending message via SSE: {session_message}")
|
logger.debug(f"Sending message via SSE: {session_message}")
|
||||||
|
|||||||
@@ -252,3 +252,69 @@ async def test_sse_client_timeout(
|
|||||||
return
|
return
|
||||||
|
|
||||||
pytest.fail("the client should have timed out and returned an error already")
|
pytest.fail("the client should have timed out and returned an error already")
|
||||||
|
|
||||||
|
|
||||||
|
def run_mounted_server(server_port: int) -> None:
|
||||||
|
app = make_server_app()
|
||||||
|
main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
|
||||||
|
server = uvicorn.Server(
|
||||||
|
config=uvicorn.Config(
|
||||||
|
app=main_app, host="127.0.0.1", port=server_port, log_level="error"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"starting server on {server_port}")
|
||||||
|
server.run()
|
||||||
|
|
||||||
|
# Give server time to start
|
||||||
|
while not server.started:
|
||||||
|
print("waiting for server to start")
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mounted_server(server_port: int) -> Generator[None, None, None]:
|
||||||
|
proc = multiprocessing.Process(
|
||||||
|
target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True
|
||||||
|
)
|
||||||
|
print("starting process")
|
||||||
|
proc.start()
|
||||||
|
|
||||||
|
# Wait for server to be running
|
||||||
|
max_attempts = 20
|
||||||
|
attempt = 0
|
||||||
|
print("waiting for server to start")
|
||||||
|
while attempt < max_attempts:
|
||||||
|
try:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.connect(("127.0.0.1", server_port))
|
||||||
|
break
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
time.sleep(0.1)
|
||||||
|
attempt += 1
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
print("killing server")
|
||||||
|
# Signal the server to stop
|
||||||
|
proc.kill()
|
||||||
|
proc.join(timeout=2)
|
||||||
|
if proc.is_alive():
|
||||||
|
print("server process failed to terminate")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sse_client_basic_connection_mounted_app(
|
||||||
|
mounted_server: None, server_url: str
|
||||||
|
) -> None:
|
||||||
|
async with sse_client(server_url + "/mounted_app/sse") as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
# Test initialization
|
||||||
|
result = await session.initialize()
|
||||||
|
assert isinstance(result, InitializeResult)
|
||||||
|
assert result.serverInfo.name == SERVER_NAME
|
||||||
|
|
||||||
|
# Test ping
|
||||||
|
ping_result = await session.send_ping()
|
||||||
|
assert isinstance(ping_result, EmptyResult)
|
||||||
|
|||||||
Reference in New Issue
Block a user