mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Properly clean up response streams in BaseSession (#515)
This commit is contained in:
@@ -187,7 +187,6 @@ class BaseSession(
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
|
||||
self._exit_stack = AsyncExitStack()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
@@ -232,45 +231,48 @@ class BaseSession(
|
||||
](1)
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
self._exit_stack.push_async_callback(lambda: response_stream.aclose())
|
||||
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())
|
||||
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||
|
||||
# request read timeout takes precedence over session read timeout
|
||||
timeout = None
|
||||
if request_read_timeout_seconds is not None:
|
||||
timeout = request_read_timeout_seconds.total_seconds()
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = self._session_read_timeout_seconds.total_seconds()
|
||||
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
except TimeoutError:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=httpx.codes.REQUEST_TIMEOUT,
|
||||
message=(
|
||||
f"Timed out while waiting for response to "
|
||||
f"{request.__class__.__name__}. Waited "
|
||||
f"{timeout} seconds."
|
||||
),
|
||||
)
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
if isinstance(response_or_error, JSONRPCError):
|
||||
raise McpError(response_or_error.error)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||
|
||||
# request read timeout takes precedence over session read timeout
|
||||
timeout = None
|
||||
if request_read_timeout_seconds is not None:
|
||||
timeout = request_read_timeout_seconds.total_seconds()
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = self._session_read_timeout_seconds.total_seconds()
|
||||
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
except TimeoutError:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=httpx.codes.REQUEST_TIMEOUT,
|
||||
message=(
|
||||
f"Timed out while waiting for response to "
|
||||
f"{request.__class__.__name__}. Waited "
|
||||
f"{timeout} seconds."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(response_or_error, JSONRPCError):
|
||||
raise McpError(response_or_error.error)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
await response_stream.aclose()
|
||||
await response_stream_reader.aclose()
|
||||
|
||||
async def send_notification(self, notification: SendNotificationT) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user