mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
StreamableHttp - client refactoring and resumability support (#595)
This commit is contained in:
@@ -261,6 +261,7 @@ class ClientSession(
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
|
||||
@@ -7,32 +7,403 @@ and session management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from httpx_sse import EventSource, aconnect_sse
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
|
||||
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.message import ClientMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Header names
|
||||
MCP_SESSION_ID_HEADER = "mcp-session-id"
|
||||
LAST_EVENT_ID_HEADER = "last-event-id"
|
||||
|
||||
# Content types
|
||||
CONTENT_TYPE_JSON = "application/json"
|
||||
CONTENT_TYPE_SSE = "text/event-stream"
|
||||
SessionMessageOrError = SessionMessage | Exception
|
||||
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
|
||||
StreamReader = MemoryObjectReceiveStream[SessionMessage]
|
||||
GetSessionIdCallback = Callable[[], str | None]
|
||||
|
||||
MCP_SESSION_ID = "mcp-session-id"
|
||||
LAST_EVENT_ID = "last-event-id"
|
||||
CONTENT_TYPE = "content-type"
|
||||
ACCEPT = "Accept"
|
||||
|
||||
|
||||
JSON = "application/json"
|
||||
SSE = "text/event-stream"
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
"""Context for a request operation."""
|
||||
|
||||
client: httpx.AsyncClient
|
||||
headers: dict[str, str]
|
||||
session_id: str | None
|
||||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
read_stream_writer: StreamWriter
|
||||
sse_read_timeout: timedelta
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
"""StreamableHTTP client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
Args:
|
||||
url: The endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
CONTENT_TYPE: JSON,
|
||||
**self.headers,
|
||||
}
|
||||
|
||||
def _update_headers_with_session(
|
||||
self, base_headers: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Update headers with session ID if available."""
|
||||
headers = base_headers.copy()
|
||||
if self.session_id:
|
||||
headers[MCP_SESSION_ID] = self.session_id
|
||||
return headers
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return (
|
||||
isinstance(message.root, JSONRPCRequest)
|
||||
and message.root.method == "initialize"
|
||||
)
|
||||
|
||||
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialized notification."""
|
||||
return (
|
||||
isinstance(message.root, JSONRPCNotification)
|
||||
and message.root.method == "notifications/initialized"
|
||||
)
|
||||
|
||||
def _maybe_extract_session_id_from_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
) -> None:
|
||||
"""Extract and store session ID from response headers."""
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||
if new_session_id:
|
||||
self.session_id = new_session_id
|
||||
logger.info(f"Received session ID: {self.session_id}")
|
||||
|
||||
async def _handle_sse_event(
|
||||
self,
|
||||
sse: ServerSentEvent,
|
||||
read_stream_writer: StreamWriter,
|
||||
original_request_id: RequestId | None = None,
|
||||
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"SSE message: {message}")
|
||||
|
||||
# If this is a response and we have original_request_id, replace it
|
||||
if original_request_id is not None and isinstance(
|
||||
message.root, JSONRPCResponse | JSONRPCError
|
||||
):
|
||||
message.root.id = original_request_id
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
|
||||
# Call resumption token callback if we have an ID
|
||||
if sse.id and resumption_callback:
|
||||
await resumption_callback(sse.id)
|
||||
|
||||
# If this is a response or error return True indicating completion
|
||||
# Otherwise, return False to continue listening
|
||||
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error parsing SSE message: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
return False
|
||||
|
||||
async def handle_get_stream(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
read_stream_writer: StreamWriter,
|
||||
) -> None:
|
||||
"""Handle GET stream for server-initiated messages."""
|
||||
try:
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(
|
||||
self.timeout.seconds, read=self.sse_read_timeout.seconds
|
||||
),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
async for sse in event_source.aiter_sse():
|
||||
await self._handle_sse_event(sse, read_stream_writer)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug(f"GET stream error (non-fatal): {exc}")
|
||||
|
||||
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a resumption request using GET with SSE."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
if ctx.metadata and ctx.metadata.resumption_token:
|
||||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
||||
else:
|
||||
raise ResumptionError("Resumption request requires a resumption token")
|
||||
|
||||
# Extract original request ID to map responses
|
||||
original_request_id = None
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
async with aconnect_sse(
|
||||
ctx.client,
|
||||
"GET",
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(
|
||||
self.timeout.seconds, read=ctx.sse_read_timeout.seconds
|
||||
),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
async for sse in event_source.aiter_sse():
|
||||
is_complete = await self._handle_sse_event(
|
||||
sse,
|
||||
ctx.read_stream_writer,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
|
||||
async def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
|
||||
async with ctx.client.stream(
|
||||
"POST",
|
||||
self.url,
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
await self._send_session_terminated_error(
|
||||
ctx.read_stream_writer,
|
||||
message.root.id,
|
||||
)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
content_type = response.headers.get(CONTENT_TYPE, "").lower()
|
||||
|
||||
if content_type.startswith(JSON):
|
||||
await self._handle_json_response(response, ctx.read_stream_writer)
|
||||
elif content_type.startswith(SSE):
|
||||
await self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
await self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.read_stream_writer,
|
||||
)
|
||||
|
||||
async def _handle_json_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
read_stream_writer: StreamWriter,
|
||||
) -> None:
|
||||
"""Handle JSON response from the server."""
|
||||
try:
|
||||
content = await response.aread()
|
||||
message = JSONRPCMessage.model_validate_json(content)
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error parsing JSON response: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
|
||||
async def _handle_sse_response(
|
||||
self, response: httpx.Response, ctx: RequestContext
|
||||
) -> None:
|
||||
"""Handle SSE response from the server."""
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
async for sse in event_source.aiter_sse():
|
||||
await self._handle_sse_event(
|
||||
sse,
|
||||
ctx.read_stream_writer,
|
||||
resumption_callback=(
|
||||
ctx.metadata.on_resumption_token_update
|
||||
if ctx.metadata
|
||||
else None
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error reading SSE stream:")
|
||||
await ctx.read_stream_writer.send(e)
|
||||
|
||||
async def _handle_unexpected_content_type(
|
||||
self,
|
||||
content_type: str,
|
||||
read_stream_writer: StreamWriter,
|
||||
) -> None:
|
||||
"""Handle unexpected content type in response."""
|
||||
error_msg = f"Unexpected content type: {content_type}"
|
||||
logger.error(error_msg)
|
||||
await read_stream_writer.send(ValueError(error_msg))
|
||||
|
||||
async def _send_session_terminated_error(
|
||||
self,
|
||||
read_stream_writer: StreamWriter,
|
||||
request_id: RequestId,
|
||||
) -> None:
|
||||
"""Send a session terminated error response."""
|
||||
jsonrpc_error = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=ErrorData(code=32600, message="Session terminated"),
|
||||
)
|
||||
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
|
||||
await read_stream_writer.send(session_message)
|
||||
|
||||
async def post_writer(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
write_stream_reader: StreamReader,
|
||||
read_stream_writer: StreamWriter,
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
start_get_stream: Callable[[], None],
|
||||
) -> None:
|
||||
"""Handle writing requests to the server."""
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
message = session_message.message
|
||||
metadata = (
|
||||
session_message.metadata
|
||||
if isinstance(session_message.metadata, ClientMessageMetadata)
|
||||
else None
|
||||
)
|
||||
|
||||
# Check if this is a resumption request
|
||||
is_resumption = bool(metadata and metadata.resumption_token)
|
||||
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
|
||||
# Handle initialized notification
|
||||
if self._is_initialized_notification(message):
|
||||
start_get_stream()
|
||||
|
||||
ctx = RequestContext(
|
||||
client=client,
|
||||
headers=self.request_headers,
|
||||
session_id=self.session_id,
|
||||
session_message=session_message,
|
||||
metadata=metadata,
|
||||
read_stream_writer=read_stream_writer,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
if is_resumption:
|
||||
await self._handle_resumption_request(ctx)
|
||||
else:
|
||||
await self._handle_post_request(ctx)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
|
||||
async def terminate_session(self, client: httpx.AsyncClient) -> None:
|
||||
"""Terminate the session by sending a DELETE request."""
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
response = await client.delete(self.url, headers=headers)
|
||||
|
||||
if response.status_code == 405:
|
||||
logger.debug("Server does not allow session termination")
|
||||
elif response.status_code != 200:
|
||||
logger.warning(f"Session termination failed: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Session termination failed: {exc}")
|
||||
|
||||
def get_session_id(self) -> str | None:
|
||||
"""Get the current session ID."""
|
||||
return self.session_id
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -41,7 +412,15 @@ async def streamablehttp_client(
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
):
|
||||
terminate_on_close: bool = True,
|
||||
) -> AsyncGenerator[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
MemoryObjectSendStream[SessionMessage],
|
||||
GetSessionIdCallback,
|
||||
],
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Client transport for StreamableHTTP.
|
||||
|
||||
@@ -49,8 +428,12 @@ async def streamablehttp_client(
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Yields:
|
||||
Tuple of (read_stream, write_stream, terminate_callback)
|
||||
Tuple containing:
|
||||
- read_stream: Stream for reading messages from the server
|
||||
- write_stream: Stream for sending messages to the server
|
||||
- get_session_id_callback: Function to retrieve the current session ID
|
||||
"""
|
||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
||||
SessionMessage | Exception
|
||||
@@ -59,208 +442,41 @@ async def streamablehttp_client(
|
||||
SessionMessage
|
||||
](0)
|
||||
|
||||
async def get_stream():
|
||||
"""
|
||||
Optional GET stream for server-initiated messages
|
||||
"""
|
||||
nonlocal session_id
|
||||
try:
|
||||
# Only attempt GET if we have a session ID
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
get_headers = request_headers.copy()
|
||||
get_headers[MCP_SESSION_ID_HEADER] = session_id
|
||||
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
url,
|
||||
headers=get_headers,
|
||||
timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
async for sse in event_source.aiter_sse():
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"GET message: {message}")
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error parsing GET message: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event from GET: {sse.event}")
|
||||
except Exception as exc:
|
||||
# GET stream is optional, so don't propagate errors
|
||||
logger.debug(f"GET stream error (non-fatal): {exc}")
|
||||
|
||||
async def post_writer(client: httpx.AsyncClient):
|
||||
nonlocal session_id
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
message = session_message.message
|
||||
# Add session ID to headers if we have one
|
||||
post_headers = request_headers.copy()
|
||||
if session_id:
|
||||
post_headers[MCP_SESSION_ID_HEADER] = session_id
|
||||
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
|
||||
# Handle initial initialization request
|
||||
is_initialization = (
|
||||
isinstance(message.root, JSONRPCRequest)
|
||||
and message.root.method == "initialize"
|
||||
)
|
||||
if (
|
||||
isinstance(message.root, JSONRPCNotification)
|
||||
and message.root.method == "notifications/initialized"
|
||||
):
|
||||
tg.start_soon(get_stream)
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
url,
|
||||
json=message.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
headers=post_headers,
|
||||
) as response:
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
continue
|
||||
# Check for 404 (session expired/invalid)
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
jsonrpc_error = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=message.root.id,
|
||||
error=ErrorData(
|
||||
code=32600,
|
||||
message="Session terminated",
|
||||
),
|
||||
)
|
||||
session_message = SessionMessage(
|
||||
JSONRPCMessage(jsonrpc_error)
|
||||
)
|
||||
await read_stream_writer.send(session_message)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
|
||||
# Extract session ID from response headers
|
||||
if is_initialization:
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
||||
if new_session_id:
|
||||
session_id = new_session_id
|
||||
logger.info(f"Received session ID: {session_id}")
|
||||
|
||||
# Handle different response types
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
|
||||
if content_type.startswith(CONTENT_TYPE_JSON):
|
||||
try:
|
||||
content = await response.aread()
|
||||
json_message = JSONRPCMessage.model_validate_json(
|
||||
content
|
||||
)
|
||||
session_message = SessionMessage(json_message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error parsing JSON response: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
|
||||
elif content_type.startswith(CONTENT_TYPE_SSE):
|
||||
# Parse SSE events from the response
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
async for sse in event_source.aiter_sse():
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = (
|
||||
JSONRPCMessage.model_validate_json(
|
||||
sse.data
|
||||
)
|
||||
)
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(
|
||||
session_message
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing message")
|
||||
await read_stream_writer.send(exc)
|
||||
else:
|
||||
logger.warning(f"Unknown event: {sse.event}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error reading SSE stream:")
|
||||
await read_stream_writer.send(e)
|
||||
|
||||
else:
|
||||
# For 202 Accepted with no body
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
continue
|
||||
|
||||
error_msg = f"Unexpected content type: {content_type}"
|
||||
logger.error(error_msg)
|
||||
await read_stream_writer.send(ValueError(error_msg))
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
|
||||
async def terminate_session():
|
||||
"""
|
||||
Terminate the session by sending a DELETE request.
|
||||
"""
|
||||
nonlocal session_id
|
||||
if not session_id:
|
||||
return # No session to terminate
|
||||
|
||||
try:
|
||||
delete_headers = request_headers.copy()
|
||||
delete_headers[MCP_SESSION_ID_HEADER] = session_id
|
||||
|
||||
response = await client.delete(
|
||||
url,
|
||||
headers=delete_headers,
|
||||
)
|
||||
|
||||
if response.status_code == 405:
|
||||
# Server doesn't allow client-initiated termination
|
||||
logger.debug("Server does not allow session termination")
|
||||
elif response.status_code != 200:
|
||||
logger.warning(f"Session termination failed: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Session termination failed: {exc}")
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
|
||||
# Set up headers with required Accept header
|
||||
request_headers = {
|
||||
"Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}",
|
||||
"Content-Type": CONTENT_TYPE_JSON,
|
||||
**(headers or {}),
|
||||
}
|
||||
# Track session ID if provided by server
|
||||
session_id: str | None = None
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers=request_headers,
|
||||
timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds),
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(
|
||||
transport.timeout.seconds, read=transport.sse_read_timeout.seconds
|
||||
),
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
tg.start_soon(post_writer, client)
|
||||
# Define callbacks that need access to tg
|
||||
def start_get_stream() -> None:
|
||||
tg.start_soon(
|
||||
transport.handle_get_stream, client, read_stream_writer
|
||||
)
|
||||
|
||||
tg.start_soon(
|
||||
transport.post_writer,
|
||||
client,
|
||||
write_stream_reader,
|
||||
read_stream_writer,
|
||||
write_stream,
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield read_stream, write_stream, terminate_session
|
||||
yield (
|
||||
read_stream,
|
||||
write_stream,
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
await transport.terminate_session(client)
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
|
||||
@@ -5,16 +5,24 @@ This module defines a wrapper type that combines JSONRPCMessage with metadata
|
||||
to support transport-specific features like resumability.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mcp.types import JSONRPCMessage, RequestId
|
||||
|
||||
ResumptionToken = str
|
||||
|
||||
ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientMessageMetadata:
|
||||
"""Metadata specific to client messages."""
|
||||
|
||||
resumption_token: str | None = None
|
||||
resumption_token: ResumptionToken | None = None
|
||||
on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = (
|
||||
None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,7 +12,7 @@ from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
@@ -213,6 +213,7 @@ class BaseSession(
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata = None,
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
@@ -241,7 +242,9 @@ class BaseSession(
|
||||
# TODO: Support progress callbacks
|
||||
|
||||
await self._write_stream.send(
|
||||
SessionMessage(message=JSONRPCMessage(jsonrpc_request))
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
|
||||
)
|
||||
)
|
||||
|
||||
# request read timeout takes precedence over session read timeout
|
||||
|
||||
Reference in New Issue
Block a user