mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-22 08:14:22 +01:00
WIP
This commit is contained in:
@@ -4,6 +4,8 @@ from starlette.applications import Starlette
|
|||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Mount, Route
|
||||||
import httpx
|
import httpx
|
||||||
from httpx import ReadTimeout, ASGITransport
|
from httpx import ReadTimeout, ASGITransport
|
||||||
|
from starlette.responses import Response
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.server.sse import SseServerTransport
|
from mcp.server.sse import SseServerTransport
|
||||||
@@ -21,17 +23,33 @@ async def sse_app(sse_transport):
|
|||||||
"""Fixture that creates a Starlette app with SSE endpoints."""
|
"""Fixture that creates a Starlette app with SSE endpoints."""
|
||||||
async def handle_sse(request):
|
async def handle_sse(request):
|
||||||
"""Handler for SSE connections."""
|
"""Handler for SSE connections."""
|
||||||
async with sse_transport.connect_sse(
|
async def event_generator():
|
||||||
request.scope, request.receive, request._send
|
# Send initial connection event
|
||||||
) as streams:
|
yield {
|
||||||
client_to_server, server_to_client = streams
|
"event": "endpoint",
|
||||||
async for message in client_to_server:
|
"data": "/messages",
|
||||||
# Echo messages back for testing
|
}
|
||||||
await server_to_client.send(message)
|
|
||||||
|
# Keep connection alive
|
||||||
|
async with sse_transport.connect_sse(
|
||||||
|
request.scope, request.receive, request._send
|
||||||
|
) as streams:
|
||||||
|
client_to_server, server_to_client = streams
|
||||||
|
async for message in client_to_server:
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": message.model_dump_json(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return EventSourceResponse(event_generator())
|
||||||
|
|
||||||
|
async def handle_post(request):
|
||||||
|
"""Handler for POST messages."""
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
routes = [
|
routes = [
|
||||||
Route("/sse", endpoint=handle_sse),
|
Route("/sse", endpoint=handle_sse),
|
||||||
Mount("/messages", app=sse_transport.handle_post_message),
|
Route("/messages", endpoint=handle_post, methods=["POST"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
return Starlette(routes=routes)
|
return Starlette(routes=routes)
|
||||||
@@ -40,9 +58,11 @@ async def sse_app(sse_transport):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_client(sse_app):
|
async def test_client(sse_app):
|
||||||
"""Create a test client with ASGI transport."""
|
"""Create a test client with ASGI transport."""
|
||||||
|
transport = ASGITransport(app=sse_app)
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
transport=ASGITransport(app=sse_app),
|
transport=transport,
|
||||||
base_url="http://testserver",
|
base_url="http://testserver",
|
||||||
|
timeout=5.0,
|
||||||
) as client:
|
) as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
@@ -53,7 +73,8 @@ async def test_sse_connection(test_client):
|
|||||||
async with sse_client(
|
async with sse_client(
|
||||||
"http://testserver/sse",
|
"http://testserver/sse",
|
||||||
headers={"Host": "testserver"},
|
headers={"Host": "testserver"},
|
||||||
timeout=5,
|
timeout=2,
|
||||||
|
sse_read_timeout=1,
|
||||||
client=test_client,
|
client=test_client,
|
||||||
) as (read_stream, write_stream):
|
) as (read_stream, write_stream):
|
||||||
# Send a test message
|
# Send a test message
|
||||||
@@ -74,7 +95,7 @@ async def test_sse_read_timeout(test_client):
|
|||||||
async with sse_client(
|
async with sse_client(
|
||||||
"http://testserver/sse",
|
"http://testserver/sse",
|
||||||
headers={"Host": "testserver"},
|
headers={"Host": "testserver"},
|
||||||
timeout=5,
|
timeout=2,
|
||||||
sse_read_timeout=1,
|
sse_read_timeout=1,
|
||||||
client=test_client,
|
client=test_client,
|
||||||
) as (read_stream, write_stream):
|
) as (read_stream, write_stream):
|
||||||
@@ -90,7 +111,7 @@ async def test_sse_connection_error(test_client):
|
|||||||
async with sse_client(
|
async with sse_client(
|
||||||
"http://testserver/nonexistent",
|
"http://testserver/nonexistent",
|
||||||
headers={"Host": "testserver"},
|
headers={"Host": "testserver"},
|
||||||
timeout=5,
|
timeout=2,
|
||||||
client=test_client,
|
client=test_client,
|
||||||
):
|
):
|
||||||
pass # Should not reach here
|
pass # Should not reach here
|
||||||
@@ -102,7 +123,8 @@ async def test_sse_multiple_messages(test_client):
|
|||||||
async with sse_client(
|
async with sse_client(
|
||||||
"http://testserver/sse",
|
"http://testserver/sse",
|
||||||
headers={"Host": "testserver"},
|
headers={"Host": "testserver"},
|
||||||
timeout=5,
|
timeout=2,
|
||||||
|
sse_read_timeout=1,
|
||||||
client=test_client,
|
client=test_client,
|
||||||
) as (read_stream, write_stream):
|
) as (read_stream, write_stream):
|
||||||
# Send multiple test messages
|
# Send multiple test messages
|
||||||
|
|||||||
Reference in New Issue
Block a user