Streamable Http - clean up server memory streams (#604)

This commit is contained in:
ihrpr
2025-05-02 14:59:17 +01:00
committed by GitHub
parent 74f5fcfa0d
commit 5d8eaf77be
4 changed files with 108 additions and 69 deletions

View File

@@ -185,10 +185,12 @@ def main(
) )
server_instances[http_transport.mcp_session_id] = http_transport server_instances[http_transport.mcp_session_id] = http_transport
logger.info(f"Created new transport with session ID: {new_session_id}") logger.info(f"Created new transport with session ID: {new_session_id}")
async def run_server(task_status=None):
async with http_transport.connect() as streams: async with http_transport.connect() as streams:
read_stream, write_stream = streams read_stream, write_stream = streams
if task_status:
async def run_server(): task_status.started()
await app.run( await app.run(
read_stream, read_stream,
write_stream, write_stream,
@@ -198,7 +200,7 @@ def main(
if not task_group: if not task_group:
raise RuntimeError("Task group is not initialized") raise RuntimeError("Task group is not initialized")
task_group.start_soon(run_server) await task_group.start(run_server)
# Handle the HTTP request and return the response # Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send) await http_transport.handle_request(scope, receive, send)

View File

@@ -480,7 +480,7 @@ class Server(Generic[LifespanResultT]):
# but also make tracing exceptions much easier during testing and when using # but also make tracing exceptions much easier during testing and when using
# in-process servers. # in-process servers.
raise_exceptions: bool = False, raise_exceptions: bool = False,
# When True, the server as stateless deployments where # When True, the server is stateless and
# clients can perform initialization with any node. The client must still follow # clients can perform initialization with any node. The client must still follow
# the initialization lifecycle, but can do so with any available node # the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection. # rather than requiring initialization for each connection.

View File

@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
None None
) )
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
def __init__( def __init__(
@@ -163,7 +165,11 @@ class StreamableHTTPServerTransport:
self.is_json_response_enabled = is_json_response_enabled self.is_json_response_enabled = is_json_response_enabled
self._event_store = event_store self._event_store = event_store
self._request_streams: dict[ self._request_streams: dict[
RequestId, MemoryObjectSendStream[EventMessage] RequestId,
tuple[
MemoryObjectSendStream[EventMessage],
MemoryObjectReceiveStream[EventMessage],
],
] = {} ] = {}
self._terminated = False self._terminated = False
@@ -239,6 +245,19 @@ class StreamableHTTPServerTransport:
return event_data return event_data
async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
"""Clean up memory streams for a given request ID."""
if request_id in self._request_streams:
try:
# Close the request stream
await self._request_streams[request_id][0].aclose()
await self._request_streams[request_id][1].aclose()
except Exception as e:
logger.debug(f"Error closing memory streams: {e}")
finally:
# Remove the request stream from the mapping
self._request_streams.pop(request_id, None)
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests""" """Application entry point that handles all HTTP requests"""
request = Request(scope, receive) request = Request(scope, receive)
@@ -386,13 +405,11 @@ class StreamableHTTPServerTransport:
# Extract the request ID outside the try block for proper scope # Extract the request ID outside the try block for proper scope
request_id = str(message.root.id) request_id = str(message.root.id)
# Create promise stream for getting response
request_stream_writer, request_stream_reader = (
anyio.create_memory_object_stream[EventMessage](0)
)
# Register this stream for the request ID # Register this stream for the request ID
self._request_streams[request_id] = request_stream_writer self._request_streams[request_id] = anyio.create_memory_object_stream[
EventMessage
](0)
request_stream_reader = self._request_streams[request_id][1]
if self.is_json_response_enabled: if self.is_json_response_enabled:
# Process the message # Process the message
@@ -441,11 +458,7 @@ class StreamableHTTPServerTransport:
) )
await response(scope, receive, send) await response(scope, receive, send)
finally: finally:
# Clean up the request stream await self._clean_up_memory_streams(request_id)
if request_id in self._request_streams:
self._request_streams.pop(request_id, None)
await request_stream_reader.aclose()
await request_stream_writer.aclose()
else: else:
# Create SSE stream # Create SSE stream
sse_stream_writer, sse_stream_reader = ( sse_stream_writer, sse_stream_reader = (
@@ -467,16 +480,12 @@ class StreamableHTTPServerTransport:
event_message.message.root, event_message.message.root,
JSONRPCResponse | JSONRPCError, JSONRPCResponse | JSONRPCError,
): ):
if request_id:
self._request_streams.pop(request_id, None)
break break
except Exception as e: except Exception as e:
logger.exception(f"Error in SSE writer: {e}") logger.exception(f"Error in SSE writer: {e}")
finally: finally:
logger.debug("Closing SSE writer") logger.debug("Closing SSE writer")
# Clean up the request-specific streams await self._clean_up_memory_streams(request_id)
if request_id and request_id in self._request_streams:
self._request_streams.pop(request_id, None)
# Create and start EventSourceResponse # Create and start EventSourceResponse
# SSE stream mode (original behavior) # SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ class StreamableHTTPServerTransport:
await writer.send(session_message) await writer.send(session_message)
except Exception: except Exception:
logger.exception("SSE response error") logger.exception("SSE response error")
# Clean up the request stream if something goes wrong await sse_stream_writer.aclose()
if request_id and request_id in self._request_streams: await sse_stream_reader.aclose()
self._request_streams.pop(request_id, None) await self._clean_up_memory_streams(request_id)
except Exception as err: except Exception as err:
logger.exception("Error handling POST request") logger.exception("Error handling POST request")
@@ -581,12 +590,11 @@ class StreamableHTTPServerTransport:
async def standalone_sse_writer(): async def standalone_sse_writer():
try: try:
# Create a standalone message stream for server-initiated messages # Create a standalone message stream for server-initiated messages
standalone_stream_writer, standalone_stream_reader = (
self._request_streams[GET_STREAM_KEY] = (
anyio.create_memory_object_stream[EventMessage](0) anyio.create_memory_object_stream[EventMessage](0)
) )
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
# Register this stream using the special key
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
async with sse_stream_writer, standalone_stream_reader: async with sse_stream_writer, standalone_stream_reader:
# Process messages from the standalone stream # Process messages from the standalone stream
@@ -603,8 +611,7 @@ class StreamableHTTPServerTransport:
logger.exception(f"Error in standalone SSE writer: {e}") logger.exception(f"Error in standalone SSE writer: {e}")
finally: finally:
logger.debug("Closing standalone SSE writer") logger.debug("Closing standalone SSE writer")
# Remove the stream from request_streams await self._clean_up_memory_streams(GET_STREAM_KEY)
self._request_streams.pop(GET_STREAM_KEY, None)
# Create and start EventSourceResponse # Create and start EventSourceResponse
response = EventSourceResponse( response = EventSourceResponse(
@@ -618,8 +625,9 @@ class StreamableHTTPServerTransport:
await response(request.scope, request.receive, send) await response(request.scope, request.receive, send)
except Exception as e: except Exception as e:
logger.exception(f"Error in standalone SSE response: {e}") logger.exception(f"Error in standalone SSE response: {e}")
# Clean up the request stream await sse_stream_writer.aclose()
self._request_streams.pop(GET_STREAM_KEY, None) await sse_stream_reader.aclose()
await self._clean_up_memory_streams(GET_STREAM_KEY)
async def _handle_delete_request(self, request: Request, send: Send) -> None: async def _handle_delete_request(self, request: Request, send: Send) -> None:
"""Handle DELETE requests for explicit session termination.""" """Handle DELETE requests for explicit session termination."""
@@ -636,7 +644,7 @@ class StreamableHTTPServerTransport:
if not await self._validate_session(request, send): if not await self._validate_session(request, send):
return return
self._terminate_session() await self._terminate_session()
response = self._create_json_response( response = self._create_json_response(
None, None,
@@ -644,7 +652,7 @@ class StreamableHTTPServerTransport:
) )
await response(request.scope, request.receive, send) await response(request.scope, request.receive, send)
def _terminate_session(self) -> None: async def _terminate_session(self) -> None:
"""Terminate the current session, closing all streams. """Terminate the current session, closing all streams.
Once terminated, all requests with this session ID will receive 404 Not Found. Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -656,19 +664,26 @@ class StreamableHTTPServerTransport:
# We need a copy of the keys to avoid modification during iteration # We need a copy of the keys to avoid modification during iteration
request_stream_keys = list(self._request_streams.keys()) request_stream_keys = list(self._request_streams.keys())
# Close all request streams (synchronously) # Close all request streams asynchronously
for key in request_stream_keys: for key in request_stream_keys:
try: try:
# Get the stream await self._clean_up_memory_streams(key)
stream = self._request_streams.get(key)
if stream:
# We must use close() here, not aclose() since this is a sync method
stream.close()
except Exception as e: except Exception as e:
logger.debug(f"Error closing stream {key} during termination: {e}") logger.debug(f"Error closing stream {key} during termination: {e}")
# Clear the request streams dictionary immediately # Clear the request streams dictionary immediately
self._request_streams.clear() self._request_streams.clear()
try:
if self._read_stream_writer is not None:
await self._read_stream_writer.aclose()
if self._read_stream is not None:
await self._read_stream.aclose()
if self._write_stream_reader is not None:
await self._write_stream_reader.aclose()
if self._write_stream is not None:
await self._write_stream.aclose()
except Exception as e:
logger.debug(f"Error closing streams: {e}")
async def _handle_unsupported_request(self, request: Request, send: Send) -> None: async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
"""Handle unsupported HTTP methods.""" """Handle unsupported HTTP methods."""
@@ -756,10 +771,10 @@ class StreamableHTTPServerTransport:
# If stream ID not in mapping, create it # If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams: if stream_id and stream_id not in self._request_streams:
msg_writer, msg_reader = anyio.create_memory_object_stream[ self._request_streams[stream_id] = (
EventMessage anyio.create_memory_object_stream[EventMessage](0)
](0) )
self._request_streams[stream_id] = msg_writer msg_reader = self._request_streams[stream_id][1]
# Forward messages to SSE # Forward messages to SSE
async with msg_reader: async with msg_reader:
@@ -781,6 +796,9 @@ class StreamableHTTPServerTransport:
await response(request.scope, request.receive, send) await response(request.scope, request.receive, send)
except Exception as e: except Exception as e:
logger.exception(f"Error in replay response: {e}") logger.exception(f"Error in replay response: {e}")
finally:
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
except Exception as e: except Exception as e:
logger.exception(f"Error replaying events: {e}") logger.exception(f"Error replaying events: {e}")
@@ -818,7 +836,9 @@ class StreamableHTTPServerTransport:
# Store the streams # Store the streams
self._read_stream_writer = read_stream_writer self._read_stream_writer = read_stream_writer
self._read_stream = read_stream
self._write_stream_reader = write_stream_reader self._write_stream_reader = write_stream_reader
self._write_stream = write_stream
# Start a task group for message routing # Start a task group for message routing
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
@@ -863,7 +883,7 @@ class StreamableHTTPServerTransport:
if request_stream_id in self._request_streams: if request_stream_id in self._request_streams:
try: try:
# Send both the message and the event ID # Send both the message and the event ID
await self._request_streams[request_stream_id].send( await self._request_streams[request_stream_id][0].send(
EventMessage(message, event_id) EventMessage(message, event_id)
) )
except ( except (
@@ -872,6 +892,12 @@ class StreamableHTTPServerTransport:
): ):
# Stream might be closed, remove from registry # Stream might be closed, remove from registry
self._request_streams.pop(request_stream_id, None) self._request_streams.pop(request_stream_id, None)
else:
logging.debug(
f"""Request stream {request_stream_id} not found
for message. Still processing message as the client
might reconnect and replay."""
)
except Exception as e: except Exception as e:
logger.exception(f"Error in message router: {e}") logger.exception(f"Error in message router: {e}")
@@ -882,9 +908,19 @@ class StreamableHTTPServerTransport:
# Yield the streams for the caller to use # Yield the streams for the caller to use
yield read_stream, write_stream yield read_stream, write_stream
finally: finally:
for stream in list(self._request_streams.values()): for stream_id in list(self._request_streams.keys()):
try: try:
await stream.aclose() await self._clean_up_memory_streams(stream_id)
except Exception: except Exception as e:
logger.debug(f"Error closing request stream: {e}")
pass pass
self._request_streams.clear() self._request_streams.clear()
# Clean up the read and write streams
try:
await read_stream_writer.aclose()
await read_stream.aclose()
await write_stream_reader.aclose()
await write_stream.aclose()
except Exception as e:
logger.debug(f"Error closing streams: {e}")

View File

@@ -234,10 +234,11 @@ def create_app(
event_store=event_store, event_store=event_store,
) )
async def run_server(task_status=None):
async with http_transport.connect() as streams: async with http_transport.connect() as streams:
read_stream, write_stream = streams read_stream, write_stream = streams
if task_status:
async def run_server(): task_status.started()
await server.run( await server.run(
read_stream, read_stream,
write_stream, write_stream,
@@ -254,7 +255,7 @@ def create_app(
# Store the instance before starting the task to prevent races # Store the instance before starting the task to prevent races
server_instances[http_transport.mcp_session_id] = http_transport server_instances[http_transport.mcp_session_id] = http_transport
task_group.start_soon(run_server) await task_group.start(run_server)
await http_transport.handle_request(scope, receive, send) await http_transport.handle_request(scope, receive, send)
else: else: