mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Streamable Http - clean up server memory streams (#604)
This commit is contained in:
@@ -185,10 +185,12 @@ def main(
|
||||
)
|
||||
server_instances[http_transport.mcp_session_id] = http_transport
|
||||
logger.info(f"Created new transport with session ID: {new_session_id}")
|
||||
|
||||
async def run_server(task_status=None):
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
|
||||
async def run_server():
|
||||
if task_status:
|
||||
task_status.started()
|
||||
await app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
@@ -198,7 +200,7 @@ def main(
|
||||
if not task_group:
|
||||
raise RuntimeError("Task group is not initialized")
|
||||
|
||||
task_group.start_soon(run_server)
|
||||
await task_group.start(run_server)
|
||||
|
||||
# Handle the HTTP request and return the response
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
|
||||
@@ -480,7 +480,7 @@ class Server(Generic[LifespanResultT]):
|
||||
# but also make tracing exceptions much easier during testing and when using
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
# When True, the server as stateless deployments where
|
||||
# When True, the server is stateless and
|
||||
# clients can perform initialization with any node. The client must still follow
|
||||
# the initialization lifecycle, but can do so with any available node
|
||||
# rather than requiring initialization for each connection.
|
||||
|
||||
@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
|
||||
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
|
||||
None
|
||||
)
|
||||
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
|
||||
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
|
||||
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
|
||||
|
||||
def __init__(
|
||||
@@ -163,7 +165,11 @@ class StreamableHTTPServerTransport:
|
||||
self.is_json_response_enabled = is_json_response_enabled
|
||||
self._event_store = event_store
|
||||
self._request_streams: dict[
|
||||
RequestId, MemoryObjectSendStream[EventMessage]
|
||||
RequestId,
|
||||
tuple[
|
||||
MemoryObjectSendStream[EventMessage],
|
||||
MemoryObjectReceiveStream[EventMessage],
|
||||
],
|
||||
] = {}
|
||||
self._terminated = False
|
||||
|
||||
@@ -239,6 +245,19 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
return event_data
|
||||
|
||||
async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
|
||||
"""Clean up memory streams for a given request ID."""
|
||||
if request_id in self._request_streams:
|
||||
try:
|
||||
# Close the request stream
|
||||
await self._request_streams[request_id][0].aclose()
|
||||
await self._request_streams[request_id][1].aclose()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing memory streams: {e}")
|
||||
finally:
|
||||
# Remove the request stream from the mapping
|
||||
self._request_streams.pop(request_id, None)
|
||||
|
||||
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Application entry point that handles all HTTP requests"""
|
||||
request = Request(scope, receive)
|
||||
@@ -386,13 +405,11 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
# Extract the request ID outside the try block for proper scope
|
||||
request_id = str(message.root.id)
|
||||
# Create promise stream for getting response
|
||||
request_stream_writer, request_stream_reader = (
|
||||
anyio.create_memory_object_stream[EventMessage](0)
|
||||
)
|
||||
|
||||
# Register this stream for the request ID
|
||||
self._request_streams[request_id] = request_stream_writer
|
||||
self._request_streams[request_id] = anyio.create_memory_object_stream[
|
||||
EventMessage
|
||||
](0)
|
||||
request_stream_reader = self._request_streams[request_id][1]
|
||||
|
||||
if self.is_json_response_enabled:
|
||||
# Process the message
|
||||
@@ -441,11 +458,7 @@ class StreamableHTTPServerTransport:
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
finally:
|
||||
# Clean up the request stream
|
||||
if request_id in self._request_streams:
|
||||
self._request_streams.pop(request_id, None)
|
||||
await request_stream_reader.aclose()
|
||||
await request_stream_writer.aclose()
|
||||
await self._clean_up_memory_streams(request_id)
|
||||
else:
|
||||
# Create SSE stream
|
||||
sse_stream_writer, sse_stream_reader = (
|
||||
@@ -467,16 +480,12 @@ class StreamableHTTPServerTransport:
|
||||
event_message.message.root,
|
||||
JSONRPCResponse | JSONRPCError,
|
||||
):
|
||||
if request_id:
|
||||
self._request_streams.pop(request_id, None)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in SSE writer: {e}")
|
||||
finally:
|
||||
logger.debug("Closing SSE writer")
|
||||
# Clean up the request-specific streams
|
||||
if request_id and request_id in self._request_streams:
|
||||
self._request_streams.pop(request_id, None)
|
||||
await self._clean_up_memory_streams(request_id)
|
||||
|
||||
# Create and start EventSourceResponse
|
||||
# SSE stream mode (original behavior)
|
||||
@@ -507,9 +516,9 @@ class StreamableHTTPServerTransport:
|
||||
await writer.send(session_message)
|
||||
except Exception:
|
||||
logger.exception("SSE response error")
|
||||
# Clean up the request stream if something goes wrong
|
||||
if request_id and request_id in self._request_streams:
|
||||
self._request_streams.pop(request_id, None)
|
||||
await sse_stream_writer.aclose()
|
||||
await sse_stream_reader.aclose()
|
||||
await self._clean_up_memory_streams(request_id)
|
||||
|
||||
except Exception as err:
|
||||
logger.exception("Error handling POST request")
|
||||
@@ -581,12 +590,11 @@ class StreamableHTTPServerTransport:
|
||||
async def standalone_sse_writer():
|
||||
try:
|
||||
# Create a standalone message stream for server-initiated messages
|
||||
standalone_stream_writer, standalone_stream_reader = (
|
||||
|
||||
self._request_streams[GET_STREAM_KEY] = (
|
||||
anyio.create_memory_object_stream[EventMessage](0)
|
||||
)
|
||||
|
||||
# Register this stream using the special key
|
||||
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
|
||||
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
|
||||
|
||||
async with sse_stream_writer, standalone_stream_reader:
|
||||
# Process messages from the standalone stream
|
||||
@@ -603,8 +611,7 @@ class StreamableHTTPServerTransport:
|
||||
logger.exception(f"Error in standalone SSE writer: {e}")
|
||||
finally:
|
||||
logger.debug("Closing standalone SSE writer")
|
||||
# Remove the stream from request_streams
|
||||
self._request_streams.pop(GET_STREAM_KEY, None)
|
||||
await self._clean_up_memory_streams(GET_STREAM_KEY)
|
||||
|
||||
# Create and start EventSourceResponse
|
||||
response = EventSourceResponse(
|
||||
@@ -618,8 +625,9 @@ class StreamableHTTPServerTransport:
|
||||
await response(request.scope, request.receive, send)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in standalone SSE response: {e}")
|
||||
# Clean up the request stream
|
||||
self._request_streams.pop(GET_STREAM_KEY, None)
|
||||
await sse_stream_writer.aclose()
|
||||
await sse_stream_reader.aclose()
|
||||
await self._clean_up_memory_streams(GET_STREAM_KEY)
|
||||
|
||||
async def _handle_delete_request(self, request: Request, send: Send) -> None:
|
||||
"""Handle DELETE requests for explicit session termination."""
|
||||
@@ -636,7 +644,7 @@ class StreamableHTTPServerTransport:
|
||||
if not await self._validate_session(request, send):
|
||||
return
|
||||
|
||||
self._terminate_session()
|
||||
await self._terminate_session()
|
||||
|
||||
response = self._create_json_response(
|
||||
None,
|
||||
@@ -644,7 +652,7 @@ class StreamableHTTPServerTransport:
|
||||
)
|
||||
await response(request.scope, request.receive, send)
|
||||
|
||||
def _terminate_session(self) -> None:
|
||||
async def _terminate_session(self) -> None:
|
||||
"""Terminate the current session, closing all streams.
|
||||
|
||||
Once terminated, all requests with this session ID will receive 404 Not Found.
|
||||
@@ -656,19 +664,26 @@ class StreamableHTTPServerTransport:
|
||||
# We need a copy of the keys to avoid modification during iteration
|
||||
request_stream_keys = list(self._request_streams.keys())
|
||||
|
||||
# Close all request streams (synchronously)
|
||||
# Close all request streams asynchronously
|
||||
for key in request_stream_keys:
|
||||
try:
|
||||
# Get the stream
|
||||
stream = self._request_streams.get(key)
|
||||
if stream:
|
||||
# We must use close() here, not aclose() since this is a sync method
|
||||
stream.close()
|
||||
await self._clean_up_memory_streams(key)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing stream {key} during termination: {e}")
|
||||
|
||||
# Clear the request streams dictionary immediately
|
||||
self._request_streams.clear()
|
||||
try:
|
||||
if self._read_stream_writer is not None:
|
||||
await self._read_stream_writer.aclose()
|
||||
if self._read_stream is not None:
|
||||
await self._read_stream.aclose()
|
||||
if self._write_stream_reader is not None:
|
||||
await self._write_stream_reader.aclose()
|
||||
if self._write_stream is not None:
|
||||
await self._write_stream.aclose()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing streams: {e}")
|
||||
|
||||
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
|
||||
"""Handle unsupported HTTP methods."""
|
||||
@@ -756,10 +771,10 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
# If stream ID not in mapping, create it
|
||||
if stream_id and stream_id not in self._request_streams:
|
||||
msg_writer, msg_reader = anyio.create_memory_object_stream[
|
||||
EventMessage
|
||||
](0)
|
||||
self._request_streams[stream_id] = msg_writer
|
||||
self._request_streams[stream_id] = (
|
||||
anyio.create_memory_object_stream[EventMessage](0)
|
||||
)
|
||||
msg_reader = self._request_streams[stream_id][1]
|
||||
|
||||
# Forward messages to SSE
|
||||
async with msg_reader:
|
||||
@@ -781,6 +796,9 @@ class StreamableHTTPServerTransport:
|
||||
await response(request.scope, request.receive, send)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in replay response: {e}")
|
||||
finally:
|
||||
await sse_stream_writer.aclose()
|
||||
await sse_stream_reader.aclose()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error replaying events: {e}")
|
||||
@@ -818,7 +836,9 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
# Store the streams
|
||||
self._read_stream_writer = read_stream_writer
|
||||
self._read_stream = read_stream
|
||||
self._write_stream_reader = write_stream_reader
|
||||
self._write_stream = write_stream
|
||||
|
||||
# Start a task group for message routing
|
||||
async with anyio.create_task_group() as tg:
|
||||
@@ -863,7 +883,7 @@ class StreamableHTTPServerTransport:
|
||||
if request_stream_id in self._request_streams:
|
||||
try:
|
||||
# Send both the message and the event ID
|
||||
await self._request_streams[request_stream_id].send(
|
||||
await self._request_streams[request_stream_id][0].send(
|
||||
EventMessage(message, event_id)
|
||||
)
|
||||
except (
|
||||
@@ -872,6 +892,12 @@ class StreamableHTTPServerTransport:
|
||||
):
|
||||
# Stream might be closed, remove from registry
|
||||
self._request_streams.pop(request_stream_id, None)
|
||||
else:
|
||||
logging.debug(
|
||||
f"""Request stream {request_stream_id} not found
|
||||
for message. Still processing message as the client
|
||||
might reconnect and replay."""
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in message router: {e}")
|
||||
|
||||
@@ -882,9 +908,19 @@ class StreamableHTTPServerTransport:
|
||||
# Yield the streams for the caller to use
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
for stream in list(self._request_streams.values()):
|
||||
for stream_id in list(self._request_streams.keys()):
|
||||
try:
|
||||
await stream.aclose()
|
||||
except Exception:
|
||||
await self._clean_up_memory_streams(stream_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing request stream: {e}")
|
||||
pass
|
||||
self._request_streams.clear()
|
||||
|
||||
# Clean up the read and write streams
|
||||
try:
|
||||
await read_stream_writer.aclose()
|
||||
await read_stream.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
await write_stream.aclose()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing streams: {e}")
|
||||
|
||||
@@ -234,10 +234,11 @@ def create_app(
|
||||
event_store=event_store,
|
||||
)
|
||||
|
||||
async def run_server(task_status=None):
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
|
||||
async def run_server():
|
||||
if task_status:
|
||||
task_status.started()
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
@@ -254,7 +255,7 @@ def create_app(
|
||||
|
||||
# Store the instance before starting the task to prevent races
|
||||
server_instances[http_transport.mcp_session_id] = http_transport
|
||||
task_group.start_soon(run_server)
|
||||
await task_group.start(run_server)
|
||||
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user