Streamable Http - clean up server memory streams (#604)

This commit is contained in:
ihrpr
2025-05-02 14:59:17 +01:00
committed by GitHub
parent 74f5fcfa0d
commit 5d8eaf77be
4 changed files with 108 additions and 69 deletions

View File

@@ -185,10 +185,12 @@ def main(
)
server_instances[http_transport.mcp_session_id] = http_transport
logger.info(f"Created new transport with session ID: {new_session_id}")
async def run_server(task_status=None):
async with http_transport.connect() as streams:
read_stream, write_stream = streams
async def run_server():
if task_status:
task_status.started()
await app.run(
read_stream,
write_stream,
@@ -198,7 +200,7 @@ def main(
if not task_group:
raise RuntimeError("Task group is not initialized")
task_group.start_soon(run_server)
await task_group.start(run_server)
# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)

View File

@@ -480,7 +480,7 @@ class Server(Generic[LifespanResultT]):
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
# When True, the server as stateless deployments where
# When True, the server is stateless and
# clients can perform initialization with any node. The client must still follow
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.

View File

@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
None
)
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
def __init__(
@@ -163,7 +165,11 @@ class StreamableHTTPServerTransport:
self.is_json_response_enabled = is_json_response_enabled
self._event_store = event_store
self._request_streams: dict[
RequestId, MemoryObjectSendStream[EventMessage]
RequestId,
tuple[
MemoryObjectSendStream[EventMessage],
MemoryObjectReceiveStream[EventMessage],
],
] = {}
self._terminated = False
@@ -239,6 +245,19 @@ class StreamableHTTPServerTransport:
return event_data
async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
"""Clean up memory streams for a given request ID."""
if request_id in self._request_streams:
try:
# Close the request stream
await self._request_streams[request_id][0].aclose()
await self._request_streams[request_id][1].aclose()
except Exception as e:
logger.debug(f"Error closing memory streams: {e}")
finally:
# Remove the request stream from the mapping
self._request_streams.pop(request_id, None)
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests"""
request = Request(scope, receive)
@@ -386,13 +405,11 @@ class StreamableHTTPServerTransport:
# Extract the request ID outside the try block for proper scope
request_id = str(message.root.id)
# Create promise stream for getting response
request_stream_writer, request_stream_reader = (
anyio.create_memory_object_stream[EventMessage](0)
)
# Register this stream for the request ID
self._request_streams[request_id] = request_stream_writer
self._request_streams[request_id] = anyio.create_memory_object_stream[
EventMessage
](0)
request_stream_reader = self._request_streams[request_id][1]
if self.is_json_response_enabled:
# Process the message
@@ -441,11 +458,7 @@ class StreamableHTTPServerTransport:
)
await response(scope, receive, send)
finally:
# Clean up the request stream
if request_id in self._request_streams:
self._request_streams.pop(request_id, None)
await request_stream_reader.aclose()
await request_stream_writer.aclose()
await self._clean_up_memory_streams(request_id)
else:
# Create SSE stream
sse_stream_writer, sse_stream_reader = (
@@ -467,16 +480,12 @@ class StreamableHTTPServerTransport:
event_message.message.root,
JSONRPCResponse | JSONRPCError,
):
if request_id:
self._request_streams.pop(request_id, None)
break
except Exception as e:
logger.exception(f"Error in SSE writer: {e}")
finally:
logger.debug("Closing SSE writer")
# Clean up the request-specific streams
if request_id and request_id in self._request_streams:
self._request_streams.pop(request_id, None)
await self._clean_up_memory_streams(request_id)
# Create and start EventSourceResponse
# SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ class StreamableHTTPServerTransport:
await writer.send(session_message)
except Exception:
logger.exception("SSE response error")
# Clean up the request stream if something goes wrong
if request_id and request_id in self._request_streams:
self._request_streams.pop(request_id, None)
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(request_id)
except Exception as err:
logger.exception("Error handling POST request")
@@ -581,12 +590,11 @@ class StreamableHTTPServerTransport:
async def standalone_sse_writer():
try:
# Create a standalone message stream for server-initiated messages
standalone_stream_writer, standalone_stream_reader = (
self._request_streams[GET_STREAM_KEY] = (
anyio.create_memory_object_stream[EventMessage](0)
)
# Register this stream using the special key
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
async with sse_stream_writer, standalone_stream_reader:
# Process messages from the standalone stream
@@ -603,8 +611,7 @@ class StreamableHTTPServerTransport:
logger.exception(f"Error in standalone SSE writer: {e}")
finally:
logger.debug("Closing standalone SSE writer")
# Remove the stream from request_streams
self._request_streams.pop(GET_STREAM_KEY, None)
await self._clean_up_memory_streams(GET_STREAM_KEY)
# Create and start EventSourceResponse
response = EventSourceResponse(
@@ -618,8 +625,9 @@ class StreamableHTTPServerTransport:
await response(request.scope, request.receive, send)
except Exception as e:
logger.exception(f"Error in standalone SSE response: {e}")
# Clean up the request stream
self._request_streams.pop(GET_STREAM_KEY, None)
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(GET_STREAM_KEY)
async def _handle_delete_request(self, request: Request, send: Send) -> None:
"""Handle DELETE requests for explicit session termination."""
@@ -636,7 +644,7 @@ class StreamableHTTPServerTransport:
if not await self._validate_session(request, send):
return
self._terminate_session()
await self._terminate_session()
response = self._create_json_response(
None,
@@ -644,7 +652,7 @@ class StreamableHTTPServerTransport:
)
await response(request.scope, request.receive, send)
def _terminate_session(self) -> None:
async def _terminate_session(self) -> None:
"""Terminate the current session, closing all streams.
Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -656,19 +664,26 @@ class StreamableHTTPServerTransport:
# We need a copy of the keys to avoid modification during iteration
request_stream_keys = list(self._request_streams.keys())
# Close all request streams (synchronously)
# Close all request streams asynchronously
for key in request_stream_keys:
try:
# Get the stream
stream = self._request_streams.get(key)
if stream:
# We must use close() here, not aclose() since this is a sync method
stream.close()
await self._clean_up_memory_streams(key)
except Exception as e:
logger.debug(f"Error closing stream {key} during termination: {e}")
# Clear the request streams dictionary immediately
self._request_streams.clear()
try:
if self._read_stream_writer is not None:
await self._read_stream_writer.aclose()
if self._read_stream is not None:
await self._read_stream.aclose()
if self._write_stream_reader is not None:
await self._write_stream_reader.aclose()
if self._write_stream is not None:
await self._write_stream.aclose()
except Exception as e:
logger.debug(f"Error closing streams: {e}")
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
"""Handle unsupported HTTP methods."""
@@ -756,10 +771,10 @@ class StreamableHTTPServerTransport:
# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams:
msg_writer, msg_reader = anyio.create_memory_object_stream[
EventMessage
](0)
self._request_streams[stream_id] = msg_writer
self._request_streams[stream_id] = (
anyio.create_memory_object_stream[EventMessage](0)
)
msg_reader = self._request_streams[stream_id][1]
# Forward messages to SSE
async with msg_reader:
@@ -781,6 +796,9 @@ class StreamableHTTPServerTransport:
await response(request.scope, request.receive, send)
except Exception as e:
logger.exception(f"Error in replay response: {e}")
finally:
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
except Exception as e:
logger.exception(f"Error replaying events: {e}")
@@ -818,7 +836,9 @@ class StreamableHTTPServerTransport:
# Store the streams
self._read_stream_writer = read_stream_writer
self._read_stream = read_stream
self._write_stream_reader = write_stream_reader
self._write_stream = write_stream
# Start a task group for message routing
async with anyio.create_task_group() as tg:
@@ -863,7 +883,7 @@ class StreamableHTTPServerTransport:
if request_stream_id in self._request_streams:
try:
# Send both the message and the event ID
await self._request_streams[request_stream_id].send(
await self._request_streams[request_stream_id][0].send(
EventMessage(message, event_id)
)
except (
@@ -872,6 +892,12 @@ class StreamableHTTPServerTransport:
):
# Stream might be closed, remove from registry
self._request_streams.pop(request_stream_id, None)
else:
logging.debug(
f"""Request stream {request_stream_id} not found
for message. Still processing message as the client
might reconnect and replay."""
)
except Exception as e:
logger.exception(f"Error in message router: {e}")
@@ -882,9 +908,19 @@ class StreamableHTTPServerTransport:
# Yield the streams for the caller to use
yield read_stream, write_stream
finally:
for stream in list(self._request_streams.values()):
for stream_id in list(self._request_streams.keys()):
try:
await stream.aclose()
except Exception:
await self._clean_up_memory_streams(stream_id)
except Exception as e:
logger.debug(f"Error closing request stream: {e}")
pass
self._request_streams.clear()
# Clean up the read and write streams
try:
await read_stream_writer.aclose()
await read_stream.aclose()
await write_stream_reader.aclose()
await write_stream.aclose()
except Exception as e:
logger.debug(f"Error closing streams: {e}")

View File

@@ -234,10 +234,11 @@ def create_app(
event_store=event_store,
)
async def run_server(task_status=None):
async with http_transport.connect() as streams:
read_stream, write_stream = streams
async def run_server():
if task_status:
task_status.started()
await server.run(
read_stream,
write_stream,
@@ -254,7 +255,7 @@ def create_app(
# Store the instance before starting the task to prevent races
server_instances[http_transport.mcp_session_id] = http_transport
task_group.start_soon(run_server)
await task_group.start(run_server)
await http_transport.handle_request(scope, receive, send)
else: