mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
WIP
This commit is contained in:
@@ -24,12 +24,20 @@ async def sse_client(
|
|||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: float = 5,
|
timeout: float = 5,
|
||||||
sse_read_timeout: float = 60 * 5,
|
sse_read_timeout: float = 60 * 5,
|
||||||
|
client: httpx.AsyncClient | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Client transport for SSE.
|
Client transport for SSE.
|
||||||
|
|
||||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to connect to
|
||||||
|
headers: Optional headers to send with the request
|
||||||
|
timeout: Connection timeout in seconds
|
||||||
|
sse_read_timeout: Read timeout in seconds
|
||||||
|
client: Optional httpx.AsyncClient instance to use for requests
|
||||||
"""
|
"""
|
||||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||||
@@ -43,7 +51,13 @@ async def sse_client(
|
|||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||||
async with httpx.AsyncClient(headers=headers) as client:
|
if client is None:
|
||||||
|
client = httpx.AsyncClient(headers=headers)
|
||||||
|
should_close_client = True
|
||||||
|
else:
|
||||||
|
should_close_client = False
|
||||||
|
|
||||||
|
try:
|
||||||
async with aconnect_sse(
|
async with aconnect_sse(
|
||||||
client,
|
client,
|
||||||
"GET",
|
"GET",
|
||||||
@@ -137,6 +151,9 @@ async def sse_client(
|
|||||||
yield read_stream, write_stream
|
yield read_stream, write_stream
|
||||||
finally:
|
finally:
|
||||||
tg.cancel_scope.cancel()
|
tg.cancel_scope.cancel()
|
||||||
|
finally:
|
||||||
|
if should_close_client:
|
||||||
|
await client.aclose()
|
||||||
finally:
|
finally:
|
||||||
await read_stream_writer.aclose()
|
await read_stream_writer.aclose()
|
||||||
await write_stream.aclose()
|
await write_stream.aclose()
|
||||||
|
|||||||
@@ -1,82 +1,127 @@
|
|||||||
import pytest
|
|
||||||
import anyio
|
import anyio
|
||||||
|
import pytest
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Mount, Route
|
||||||
import uvicorn
|
|
||||||
from mcp.client.sse import sse_client
|
|
||||||
from exceptiongroup import ExceptionGroup
|
|
||||||
import asyncio
|
|
||||||
import httpx
|
import httpx
|
||||||
from httpx import ReadTimeout
|
from httpx import ReadTimeout, ASGITransport
|
||||||
|
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
from mcp.server.sse import SseServerTransport
|
from mcp.server.sse import SseServerTransport
|
||||||
|
from mcp.types import JSONRPCMessage
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def sse_server():
|
async def sse_transport():
|
||||||
|
"""Fixture that creates an SSE transport instance."""
|
||||||
|
return SseServerTransport("/messages/")
|
||||||
|
|
||||||
# Create an SSE transport at an endpoint
|
|
||||||
sse = SseServerTransport("/messages/")
|
|
||||||
|
|
||||||
# Create Starlette routes for SSE and message handling
|
@pytest.fixture
|
||||||
routes = [
|
async def sse_app(sse_transport):
|
||||||
Route("/sse", endpoint=handle_sse),
|
"""Fixture that creates a Starlette app with SSE endpoints."""
|
||||||
Mount("/messages/", app=sse.handle_post_message),
|
|
||||||
]
|
|
||||||
#
|
|
||||||
# Create and run Starlette app
|
|
||||||
app = Starlette(routes=routes)
|
|
||||||
|
|
||||||
# Define handler functions
|
|
||||||
async def handle_sse(request):
|
async def handle_sse(request):
|
||||||
async with sse.connect_sse(
|
"""Handler for SSE connections."""
|
||||||
|
async with sse_transport.connect_sse(
|
||||||
request.scope, request.receive, request._send
|
request.scope, request.receive, request._send
|
||||||
) as streams:
|
) as streams:
|
||||||
await app.run(
|
client_to_server, server_to_client = streams
|
||||||
streams[0], streams[1], app.create_initialization_options()
|
async for message in client_to_server:
|
||||||
)
|
# Echo messages back for testing
|
||||||
|
await server_to_client.send(message)
|
||||||
|
|
||||||
uvicorn.run(app, host="127.0.0.1", port=34891)
|
routes = [
|
||||||
|
Route("/sse", endpoint=handle_sse),
|
||||||
|
Mount("/messages", app=sse_transport.handle_post_message),
|
||||||
|
]
|
||||||
|
|
||||||
async def sse_handler(request):
|
return Starlette(routes=routes)
|
||||||
response = httpx.Response(200, content_type="text/event-stream")
|
|
||||||
response.send_headers()
|
|
||||||
response.write("data: test\n\n")
|
|
||||||
await response.aclose()
|
|
||||||
|
|
||||||
async with httpx.AsyncServer(sse_handler) as server:
|
|
||||||
yield server.url
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def sse_client():
|
async def test_client(sse_app):
|
||||||
async with sse_client("http://test/sse") as (read_stream, write_stream):
|
"""Create a test client with ASGI transport."""
|
||||||
async with read_stream:
|
async with httpx.AsyncClient(
|
||||||
async for message in read_stream:
|
transport=ASGITransport(app=sse_app),
|
||||||
if isinstance(message, Exception):
|
base_url="http://testserver",
|
||||||
raise message
|
) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
return read_stream, write_stream
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sse_happy_path(monkeypatch):
|
async def test_sse_connection(test_client):
|
||||||
# Mock httpx.AsyncClient to return our mock response
|
"""Test basic SSE connection and message exchange."""
|
||||||
monkeypatch.setattr(httpx, "AsyncClient", MockClient)
|
async with sse_client(
|
||||||
|
"http://testserver/sse",
|
||||||
|
headers={"Host": "testserver"},
|
||||||
|
timeout=5,
|
||||||
|
client=test_client,
|
||||||
|
) as (read_stream, write_stream):
|
||||||
|
# Send a test message
|
||||||
|
test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"})
|
||||||
|
await write_stream.send(test_message)
|
||||||
|
|
||||||
with pytest.raises(ReadTimeout) as exc_info:
|
# Receive echoed message
|
||||||
|
async with read_stream:
|
||||||
|
message = await read_stream.__anext__()
|
||||||
|
assert isinstance(message, JSONRPCMessage)
|
||||||
|
assert message.model_dump() == test_message.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sse_read_timeout(test_client):
|
||||||
|
"""Test that SSE client properly handles read timeouts."""
|
||||||
|
with pytest.raises(ReadTimeout):
|
||||||
async with sse_client(
|
async with sse_client(
|
||||||
"http://test/sse",
|
"http://testserver/sse",
|
||||||
timeout=5, # Connection timeout - make this longer
|
headers={"Host": "testserver"},
|
||||||
sse_read_timeout=1 # Read timeout - this should trigger
|
timeout=5,
|
||||||
|
sse_read_timeout=1,
|
||||||
|
client=test_client,
|
||||||
) as (read_stream, write_stream):
|
) as (read_stream, write_stream):
|
||||||
async with read_stream:
|
async with read_stream:
|
||||||
async for message in read_stream:
|
# This should timeout since no messages are being sent
|
||||||
if isinstance(message, Exception):
|
await read_stream.__anext__()
|
||||||
raise message
|
|
||||||
|
|
||||||
error = exc_info.value
|
|
||||||
assert isinstance(error, ReadTimeout)
|
|
||||||
assert str(error) == "Read timeout"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sse_read_timeouts(monkeypatch):
|
async def test_sse_connection_error(test_client):
|
||||||
"""Test that the SSE client properly handles read timeouts between SSE messages."""
|
"""Test SSE client behavior with connection errors."""
|
||||||
|
with pytest.raises(httpx.HTTPError):
|
||||||
|
async with sse_client(
|
||||||
|
"http://testserver/nonexistent",
|
||||||
|
headers={"Host": "testserver"},
|
||||||
|
timeout=5,
|
||||||
|
client=test_client,
|
||||||
|
):
|
||||||
|
pass # Should not reach here
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sse_multiple_messages(test_client):
|
||||||
|
"""Test sending and receiving multiple SSE messages."""
|
||||||
|
async with sse_client(
|
||||||
|
"http://testserver/sse",
|
||||||
|
headers={"Host": "testserver"},
|
||||||
|
timeout=5,
|
||||||
|
client=test_client,
|
||||||
|
) as (read_stream, write_stream):
|
||||||
|
# Send multiple test messages
|
||||||
|
messages = [
|
||||||
|
JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"})
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
await write_stream.send(msg)
|
||||||
|
|
||||||
|
# Receive all echoed messages
|
||||||
|
received = []
|
||||||
|
async with read_stream:
|
||||||
|
for _ in range(len(messages)):
|
||||||
|
message = await read_stream.__anext__()
|
||||||
|
assert isinstance(message, JSONRPCMessage)
|
||||||
|
received.append(message)
|
||||||
|
|
||||||
|
# Verify all messages were received in order
|
||||||
|
for sent, received in zip(messages, received):
|
||||||
|
assert sent.model_dump() == received.model_dump()
|
||||||
Reference in New Issue
Block a user