Handle SSE Disconnects Properly (#612)

This commit is contained in:
Akash D
2025-05-02 09:32:46 -07:00
committed by GitHub
parent 5d8eaf77be
commit 83968b5b2f
7 changed files with 38 additions and 11 deletions

View File

@@ -90,6 +90,7 @@ def main(port: int, transport: str) -> int:
if transport == "sse": if transport == "sse":
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
@@ -101,6 +102,7 @@ def main(port: int, transport: str) -> int:
await app.run( await app.run(
streams[0], streams[1], app.create_initialization_options() streams[0], streams[1], app.create_initialization_options()
) )
return Response()
starlette_app = Starlette( starlette_app = Starlette(
debug=True, debug=True,

View File

@@ -46,6 +46,7 @@ def main(port: int, transport: str) -> int:
if transport == "sse": if transport == "sse":
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
@@ -57,11 +58,12 @@ def main(port: int, transport: str) -> int:
await app.run( await app.run(
streams[0], streams[1], app.create_initialization_options() streams[0], streams[1], app.create_initialization_options()
) )
return Response()
starlette_app = Starlette( starlette_app = Starlette(
debug=True, debug=True,
routes=[ routes=[
Route("/sse", endpoint=handle_sse), Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message), Mount("/messages/", app=sse.handle_post_message),
], ],
) )

View File

@@ -60,6 +60,7 @@ def main(port: int, transport: str) -> int:
if transport == "sse": if transport == "sse":
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
@@ -71,11 +72,12 @@ def main(port: int, transport: str) -> int:
await app.run( await app.run(
streams[0], streams[1], app.create_initialization_options() streams[0], streams[1], app.create_initialization_options()
) )
return Response()
starlette_app = Starlette( starlette_app = Starlette(
debug=True, debug=True,
routes=[ routes=[
Route("/sse", endpoint=handle_sse), Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message), Mount("/messages/", app=sse.handle_post_message),
], ],
) )

View File

@@ -589,6 +589,7 @@ class FastMCP:
streams[1], streams[1],
self._mcp_server.create_initialization_options(), self._mcp_server.create_initialization_options(),
) )
return Response()
# Create routes # Create routes
routes: list[Route | Mount] = [] routes: list[Route | Mount] = []

View File

@@ -104,9 +104,6 @@ class ServerSession(
self._exit_stack.push_async_callback( self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose() lambda: self._incoming_message_stream_reader.aclose()
) )
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)
@property @property
def client_params(self) -> types.InitializeRequestParams | None: def client_params(self) -> types.InitializeRequestParams | None:
@@ -144,6 +141,10 @@ class ServerSession(
return True return True
async def _receive_loop(self) -> None:
async with self._incoming_message_stream_writer:
await super()._receive_loop()
async def _received_request( async def _received_request(
self, responder: RequestResponder[types.ClientRequest, types.ServerResult] self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
): ):

View File

@@ -10,7 +10,7 @@ Example usage:
# Create Starlette routes for SSE and message handling # Create Starlette routes for SSE and message handling
routes = [ routes = [
Route("/sse", endpoint=handle_sse), Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message), Mount("/messages/", app=sse.handle_post_message),
] ]
@@ -22,12 +22,18 @@ Example usage:
await app.run( await app.run(
streams[0], streams[1], app.create_initialization_options() streams[0], streams[1], app.create_initialization_options()
) )
# Return empty response to avoid NoneType error
return Response()
# Create and run Starlette app # Create and run Starlette app
starlette_app = Starlette(routes=routes) starlette_app = Starlette(routes=routes)
uvicorn.run(starlette_app, host="0.0.0.0", port=port) uvicorn.run(starlette_app, host="0.0.0.0", port=port)
``` ```
Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType'
object is not callable" error when client disconnects. The example above returns
an empty Response() after the SSE connection ends to fix this.
See SseServerTransport class documentation for more details. See SseServerTransport class documentation for more details.
""" """
@@ -120,11 +126,22 @@ class SseServerTransport:
) )
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
response = EventSourceResponse(
content=sse_stream_reader, data_sender_callable=sse_writer async def response_wrapper(scope: Scope, receive: Receive, send: Send):
) """
The EventSourceResponse returning signals a client close / disconnect.
In this case we close our side of the streams to signal the client that
the connection has been closed.
"""
await EventSourceResponse(
content=sse_stream_reader, data_sender_callable=sse_writer
)(scope, receive, send)
await read_stream_writer.aclose()
await write_stream_reader.aclose()
logging.debug(f"Client session disconnected {session_id}")
logger.debug("Starting SSE response task") logger.debug("Starting SSE response task")
tg.start_soon(response, scope, receive, send) tg.start_soon(response_wrapper, scope, receive, send)
logger.debug("Yielding read and write streams") logger.debug("Yielding read and write streams")
yield (read_stream, write_stream) yield (read_stream, write_stream)

View File

@@ -10,6 +10,7 @@ import uvicorn
from pydantic import AnyUrl from pydantic import AnyUrl
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
@@ -83,13 +84,14 @@ def make_server_app() -> Starlette:
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
server = ServerTest() server = ServerTest()
async def handle_sse(request: Request) -> None: async def handle_sse(request: Request) -> Response:
async with sse.connect_sse( async with sse.connect_sse(
request.scope, request.receive, request._send request.scope, request.receive, request._send
) as streams: ) as streams:
await server.run( await server.run(
streams[0], streams[1], server.create_initialization_options() streams[0], streams[1], server.create_initialization_options()
) )
return Response()
app = Starlette( app = Starlette(
routes=[ routes=[