mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Streamable Http - clean up server memory streams (#604)
This commit is contained in:
@@ -185,20 +185,22 @@ def main(
|
|||||||
)
|
)
|
||||||
server_instances[http_transport.mcp_session_id] = http_transport
|
server_instances[http_transport.mcp_session_id] = http_transport
|
||||||
logger.info(f"Created new transport with session ID: {new_session_id}")
|
logger.info(f"Created new transport with session ID: {new_session_id}")
|
||||||
async with http_transport.connect() as streams:
|
|
||||||
read_stream, write_stream = streams
|
|
||||||
|
|
||||||
async def run_server():
|
async def run_server(task_status=None):
|
||||||
await app.run(
|
async with http_transport.connect() as streams:
|
||||||
read_stream,
|
read_stream, write_stream = streams
|
||||||
write_stream,
|
if task_status:
|
||||||
app.create_initialization_options(),
|
task_status.started()
|
||||||
)
|
await app.run(
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
app.create_initialization_options(),
|
||||||
|
)
|
||||||
|
|
||||||
if not task_group:
|
if not task_group:
|
||||||
raise RuntimeError("Task group is not initialized")
|
raise RuntimeError("Task group is not initialized")
|
||||||
|
|
||||||
task_group.start_soon(run_server)
|
await task_group.start(run_server)
|
||||||
|
|
||||||
# Handle the HTTP request and return the response
|
# Handle the HTTP request and return the response
|
||||||
await http_transport.handle_request(scope, receive, send)
|
await http_transport.handle_request(scope, receive, send)
|
||||||
|
|||||||
@@ -480,7 +480,7 @@ class Server(Generic[LifespanResultT]):
|
|||||||
# but also make tracing exceptions much easier during testing and when using
|
# but also make tracing exceptions much easier during testing and when using
|
||||||
# in-process servers.
|
# in-process servers.
|
||||||
raise_exceptions: bool = False,
|
raise_exceptions: bool = False,
|
||||||
# When True, the server as stateless deployments where
|
# When True, the server is stateless and
|
||||||
# clients can perform initialization with any node. The client must still follow
|
# clients can perform initialization with any node. The client must still follow
|
||||||
# the initialization lifecycle, but can do so with any available node
|
# the initialization lifecycle, but can do so with any available node
|
||||||
# rather than requiring initialization for each connection.
|
# rather than requiring initialization for each connection.
|
||||||
|
|||||||
@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
|
|||||||
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
|
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
|
||||||
|
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
|
||||||
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
|
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -163,7 +165,11 @@ class StreamableHTTPServerTransport:
|
|||||||
self.is_json_response_enabled = is_json_response_enabled
|
self.is_json_response_enabled = is_json_response_enabled
|
||||||
self._event_store = event_store
|
self._event_store = event_store
|
||||||
self._request_streams: dict[
|
self._request_streams: dict[
|
||||||
RequestId, MemoryObjectSendStream[EventMessage]
|
RequestId,
|
||||||
|
tuple[
|
||||||
|
MemoryObjectSendStream[EventMessage],
|
||||||
|
MemoryObjectReceiveStream[EventMessage],
|
||||||
|
],
|
||||||
] = {}
|
] = {}
|
||||||
self._terminated = False
|
self._terminated = False
|
||||||
|
|
||||||
@@ -239,6 +245,19 @@ class StreamableHTTPServerTransport:
|
|||||||
|
|
||||||
return event_data
|
return event_data
|
||||||
|
|
||||||
|
async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
|
||||||
|
"""Clean up memory streams for a given request ID."""
|
||||||
|
if request_id in self._request_streams:
|
||||||
|
try:
|
||||||
|
# Close the request stream
|
||||||
|
await self._request_streams[request_id][0].aclose()
|
||||||
|
await self._request_streams[request_id][1].aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error closing memory streams: {e}")
|
||||||
|
finally:
|
||||||
|
# Remove the request stream from the mapping
|
||||||
|
self._request_streams.pop(request_id, None)
|
||||||
|
|
||||||
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
|
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
"""Application entry point that handles all HTTP requests"""
|
"""Application entry point that handles all HTTP requests"""
|
||||||
request = Request(scope, receive)
|
request = Request(scope, receive)
|
||||||
@@ -386,13 +405,11 @@ class StreamableHTTPServerTransport:
|
|||||||
|
|
||||||
# Extract the request ID outside the try block for proper scope
|
# Extract the request ID outside the try block for proper scope
|
||||||
request_id = str(message.root.id)
|
request_id = str(message.root.id)
|
||||||
# Create promise stream for getting response
|
|
||||||
request_stream_writer, request_stream_reader = (
|
|
||||||
anyio.create_memory_object_stream[EventMessage](0)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register this stream for the request ID
|
# Register this stream for the request ID
|
||||||
self._request_streams[request_id] = request_stream_writer
|
self._request_streams[request_id] = anyio.create_memory_object_stream[
|
||||||
|
EventMessage
|
||||||
|
](0)
|
||||||
|
request_stream_reader = self._request_streams[request_id][1]
|
||||||
|
|
||||||
if self.is_json_response_enabled:
|
if self.is_json_response_enabled:
|
||||||
# Process the message
|
# Process the message
|
||||||
@@ -441,11 +458,7 @@ class StreamableHTTPServerTransport:
|
|||||||
)
|
)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
finally:
|
finally:
|
||||||
# Clean up the request stream
|
await self._clean_up_memory_streams(request_id)
|
||||||
if request_id in self._request_streams:
|
|
||||||
self._request_streams.pop(request_id, None)
|
|
||||||
await request_stream_reader.aclose()
|
|
||||||
await request_stream_writer.aclose()
|
|
||||||
else:
|
else:
|
||||||
# Create SSE stream
|
# Create SSE stream
|
||||||
sse_stream_writer, sse_stream_reader = (
|
sse_stream_writer, sse_stream_reader = (
|
||||||
@@ -467,16 +480,12 @@ class StreamableHTTPServerTransport:
|
|||||||
event_message.message.root,
|
event_message.message.root,
|
||||||
JSONRPCResponse | JSONRPCError,
|
JSONRPCResponse | JSONRPCError,
|
||||||
):
|
):
|
||||||
if request_id:
|
|
||||||
self._request_streams.pop(request_id, None)
|
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error in SSE writer: {e}")
|
logger.exception(f"Error in SSE writer: {e}")
|
||||||
finally:
|
finally:
|
||||||
logger.debug("Closing SSE writer")
|
logger.debug("Closing SSE writer")
|
||||||
# Clean up the request-specific streams
|
await self._clean_up_memory_streams(request_id)
|
||||||
if request_id and request_id in self._request_streams:
|
|
||||||
self._request_streams.pop(request_id, None)
|
|
||||||
|
|
||||||
# Create and start EventSourceResponse
|
# Create and start EventSourceResponse
|
||||||
# SSE stream mode (original behavior)
|
# SSE stream mode (original behavior)
|
||||||
@@ -507,9 +516,9 @@ class StreamableHTTPServerTransport:
|
|||||||
await writer.send(session_message)
|
await writer.send(session_message)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("SSE response error")
|
logger.exception("SSE response error")
|
||||||
# Clean up the request stream if something goes wrong
|
await sse_stream_writer.aclose()
|
||||||
if request_id and request_id in self._request_streams:
|
await sse_stream_reader.aclose()
|
||||||
self._request_streams.pop(request_id, None)
|
await self._clean_up_memory_streams(request_id)
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.exception("Error handling POST request")
|
logger.exception("Error handling POST request")
|
||||||
@@ -581,12 +590,11 @@ class StreamableHTTPServerTransport:
|
|||||||
async def standalone_sse_writer():
|
async def standalone_sse_writer():
|
||||||
try:
|
try:
|
||||||
# Create a standalone message stream for server-initiated messages
|
# Create a standalone message stream for server-initiated messages
|
||||||
standalone_stream_writer, standalone_stream_reader = (
|
|
||||||
|
self._request_streams[GET_STREAM_KEY] = (
|
||||||
anyio.create_memory_object_stream[EventMessage](0)
|
anyio.create_memory_object_stream[EventMessage](0)
|
||||||
)
|
)
|
||||||
|
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
|
||||||
# Register this stream using the special key
|
|
||||||
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
|
|
||||||
|
|
||||||
async with sse_stream_writer, standalone_stream_reader:
|
async with sse_stream_writer, standalone_stream_reader:
|
||||||
# Process messages from the standalone stream
|
# Process messages from the standalone stream
|
||||||
@@ -603,8 +611,7 @@ class StreamableHTTPServerTransport:
|
|||||||
logger.exception(f"Error in standalone SSE writer: {e}")
|
logger.exception(f"Error in standalone SSE writer: {e}")
|
||||||
finally:
|
finally:
|
||||||
logger.debug("Closing standalone SSE writer")
|
logger.debug("Closing standalone SSE writer")
|
||||||
# Remove the stream from request_streams
|
await self._clean_up_memory_streams(GET_STREAM_KEY)
|
||||||
self._request_streams.pop(GET_STREAM_KEY, None)
|
|
||||||
|
|
||||||
# Create and start EventSourceResponse
|
# Create and start EventSourceResponse
|
||||||
response = EventSourceResponse(
|
response = EventSourceResponse(
|
||||||
@@ -618,8 +625,9 @@ class StreamableHTTPServerTransport:
|
|||||||
await response(request.scope, request.receive, send)
|
await response(request.scope, request.receive, send)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error in standalone SSE response: {e}")
|
logger.exception(f"Error in standalone SSE response: {e}")
|
||||||
# Clean up the request stream
|
await sse_stream_writer.aclose()
|
||||||
self._request_streams.pop(GET_STREAM_KEY, None)
|
await sse_stream_reader.aclose()
|
||||||
|
await self._clean_up_memory_streams(GET_STREAM_KEY)
|
||||||
|
|
||||||
async def _handle_delete_request(self, request: Request, send: Send) -> None:
|
async def _handle_delete_request(self, request: Request, send: Send) -> None:
|
||||||
"""Handle DELETE requests for explicit session termination."""
|
"""Handle DELETE requests for explicit session termination."""
|
||||||
@@ -636,7 +644,7 @@ class StreamableHTTPServerTransport:
|
|||||||
if not await self._validate_session(request, send):
|
if not await self._validate_session(request, send):
|
||||||
return
|
return
|
||||||
|
|
||||||
self._terminate_session()
|
await self._terminate_session()
|
||||||
|
|
||||||
response = self._create_json_response(
|
response = self._create_json_response(
|
||||||
None,
|
None,
|
||||||
@@ -644,7 +652,7 @@ class StreamableHTTPServerTransport:
|
|||||||
)
|
)
|
||||||
await response(request.scope, request.receive, send)
|
await response(request.scope, request.receive, send)
|
||||||
|
|
||||||
def _terminate_session(self) -> None:
|
async def _terminate_session(self) -> None:
|
||||||
"""Terminate the current session, closing all streams.
|
"""Terminate the current session, closing all streams.
|
||||||
|
|
||||||
Once terminated, all requests with this session ID will receive 404 Not Found.
|
Once terminated, all requests with this session ID will receive 404 Not Found.
|
||||||
@@ -656,19 +664,26 @@ class StreamableHTTPServerTransport:
|
|||||||
# We need a copy of the keys to avoid modification during iteration
|
# We need a copy of the keys to avoid modification during iteration
|
||||||
request_stream_keys = list(self._request_streams.keys())
|
request_stream_keys = list(self._request_streams.keys())
|
||||||
|
|
||||||
# Close all request streams (synchronously)
|
# Close all request streams asynchronously
|
||||||
for key in request_stream_keys:
|
for key in request_stream_keys:
|
||||||
try:
|
try:
|
||||||
# Get the stream
|
await self._clean_up_memory_streams(key)
|
||||||
stream = self._request_streams.get(key)
|
|
||||||
if stream:
|
|
||||||
# We must use close() here, not aclose() since this is a sync method
|
|
||||||
stream.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error closing stream {key} during termination: {e}")
|
logger.debug(f"Error closing stream {key} during termination: {e}")
|
||||||
|
|
||||||
# Clear the request streams dictionary immediately
|
# Clear the request streams dictionary immediately
|
||||||
self._request_streams.clear()
|
self._request_streams.clear()
|
||||||
|
try:
|
||||||
|
if self._read_stream_writer is not None:
|
||||||
|
await self._read_stream_writer.aclose()
|
||||||
|
if self._read_stream is not None:
|
||||||
|
await self._read_stream.aclose()
|
||||||
|
if self._write_stream_reader is not None:
|
||||||
|
await self._write_stream_reader.aclose()
|
||||||
|
if self._write_stream is not None:
|
||||||
|
await self._write_stream.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error closing streams: {e}")
|
||||||
|
|
||||||
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
|
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
|
||||||
"""Handle unsupported HTTP methods."""
|
"""Handle unsupported HTTP methods."""
|
||||||
@@ -756,10 +771,10 @@ class StreamableHTTPServerTransport:
|
|||||||
|
|
||||||
# If stream ID not in mapping, create it
|
# If stream ID not in mapping, create it
|
||||||
if stream_id and stream_id not in self._request_streams:
|
if stream_id and stream_id not in self._request_streams:
|
||||||
msg_writer, msg_reader = anyio.create_memory_object_stream[
|
self._request_streams[stream_id] = (
|
||||||
EventMessage
|
anyio.create_memory_object_stream[EventMessage](0)
|
||||||
](0)
|
)
|
||||||
self._request_streams[stream_id] = msg_writer
|
msg_reader = self._request_streams[stream_id][1]
|
||||||
|
|
||||||
# Forward messages to SSE
|
# Forward messages to SSE
|
||||||
async with msg_reader:
|
async with msg_reader:
|
||||||
@@ -781,6 +796,9 @@ class StreamableHTTPServerTransport:
|
|||||||
await response(request.scope, request.receive, send)
|
await response(request.scope, request.receive, send)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error in replay response: {e}")
|
logger.exception(f"Error in replay response: {e}")
|
||||||
|
finally:
|
||||||
|
await sse_stream_writer.aclose()
|
||||||
|
await sse_stream_reader.aclose()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error replaying events: {e}")
|
logger.exception(f"Error replaying events: {e}")
|
||||||
@@ -818,7 +836,9 @@ class StreamableHTTPServerTransport:
|
|||||||
|
|
||||||
# Store the streams
|
# Store the streams
|
||||||
self._read_stream_writer = read_stream_writer
|
self._read_stream_writer = read_stream_writer
|
||||||
|
self._read_stream = read_stream
|
||||||
self._write_stream_reader = write_stream_reader
|
self._write_stream_reader = write_stream_reader
|
||||||
|
self._write_stream = write_stream
|
||||||
|
|
||||||
# Start a task group for message routing
|
# Start a task group for message routing
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
@@ -863,7 +883,7 @@ class StreamableHTTPServerTransport:
|
|||||||
if request_stream_id in self._request_streams:
|
if request_stream_id in self._request_streams:
|
||||||
try:
|
try:
|
||||||
# Send both the message and the event ID
|
# Send both the message and the event ID
|
||||||
await self._request_streams[request_stream_id].send(
|
await self._request_streams[request_stream_id][0].send(
|
||||||
EventMessage(message, event_id)
|
EventMessage(message, event_id)
|
||||||
)
|
)
|
||||||
except (
|
except (
|
||||||
@@ -872,6 +892,12 @@ class StreamableHTTPServerTransport:
|
|||||||
):
|
):
|
||||||
# Stream might be closed, remove from registry
|
# Stream might be closed, remove from registry
|
||||||
self._request_streams.pop(request_stream_id, None)
|
self._request_streams.pop(request_stream_id, None)
|
||||||
|
else:
|
||||||
|
logging.debug(
|
||||||
|
f"""Request stream {request_stream_id} not found
|
||||||
|
for message. Still processing message as the client
|
||||||
|
might reconnect and replay."""
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error in message router: {e}")
|
logger.exception(f"Error in message router: {e}")
|
||||||
|
|
||||||
@@ -882,9 +908,19 @@ class StreamableHTTPServerTransport:
|
|||||||
# Yield the streams for the caller to use
|
# Yield the streams for the caller to use
|
||||||
yield read_stream, write_stream
|
yield read_stream, write_stream
|
||||||
finally:
|
finally:
|
||||||
for stream in list(self._request_streams.values()):
|
for stream_id in list(self._request_streams.keys()):
|
||||||
try:
|
try:
|
||||||
await stream.aclose()
|
await self._clean_up_memory_streams(stream_id)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.debug(f"Error closing request stream: {e}")
|
||||||
pass
|
pass
|
||||||
self._request_streams.clear()
|
self._request_streams.clear()
|
||||||
|
|
||||||
|
# Clean up the read and write streams
|
||||||
|
try:
|
||||||
|
await read_stream_writer.aclose()
|
||||||
|
await read_stream.aclose()
|
||||||
|
await write_stream_reader.aclose()
|
||||||
|
await write_stream.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error closing streams: {e}")
|
||||||
|
|||||||
@@ -234,29 +234,30 @@ def create_app(
|
|||||||
event_store=event_store,
|
event_store=event_store,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with http_transport.connect() as streams:
|
async def run_server(task_status=None):
|
||||||
read_stream, write_stream = streams
|
async with http_transport.connect() as streams:
|
||||||
|
read_stream, write_stream = streams
|
||||||
async def run_server():
|
if task_status:
|
||||||
|
task_status.started()
|
||||||
await server.run(
|
await server.run(
|
||||||
read_stream,
|
read_stream,
|
||||||
write_stream,
|
write_stream,
|
||||||
server.create_initialization_options(),
|
server.create_initialization_options(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if task_group is None:
|
if task_group is None:
|
||||||
response = Response(
|
response = Response(
|
||||||
"Internal Server Error: Task group is not initialized",
|
"Internal Server Error: Task group is not initialized",
|
||||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Store the instance before starting the task to prevent races
|
# Store the instance before starting the task to prevent races
|
||||||
server_instances[http_transport.mcp_session_id] = http_transport
|
server_instances[http_transport.mcp_session_id] = http_transport
|
||||||
task_group.start_soon(run_server)
|
await task_group.start(run_server)
|
||||||
|
|
||||||
await http_transport.handle_request(scope, receive, send)
|
await http_transport.handle_request(scope, receive, send)
|
||||||
else:
|
else:
|
||||||
response = Response(
|
response = Response(
|
||||||
"Bad Request: No valid session ID provided",
|
"Bad Request: No valid session ID provided",
|
||||||
|
|||||||
Reference in New Issue
Block a user