add auth client sse (#760)

This commit is contained in:
ihrpr
2025-05-20 15:21:14 +01:00
committed by GitHub
parent 5a9340af71
commit 43ded92633
3 changed files with 64 additions and 33 deletions

View File

@@ -1,11 +1,11 @@
# Simple Auth Client Example # Simple Auth Client Example
A demonstration of how to use the MCP Python SDK with OAuth authentication over streamable HTTP transport. A demonstration of how to use the MCP Python SDK with OAuth authentication over streamable HTTP or SSE transport.
## Features ## Features
- OAuth 2.0 authentication with PKCE - OAuth 2.0 authentication with PKCE
- Streamable HTTP transport - Support for both StreamableHTTP and SSE transports
- Interactive command-line interface - Interactive command-line interface
## Installation ## Installation
@@ -31,7 +31,10 @@ uv run mcp-simple-auth --transport streamable-http --port 3001
uv run mcp-simple-auth-client uv run mcp-simple-auth-client
# Or with custom server URL # Or with custom server URL
MCP_SERVER_URL=http://localhost:3001 uv run mcp-simple-auth-client MCP_SERVER_PORT=3001 uv run mcp-simple-auth-client
# Use SSE transport
MCP_TRANSPORT_TYPE=sse uv run mcp-simple-auth-client
``` ```
### 3. Complete OAuth flow ### 3. Complete OAuth flow
@@ -67,4 +70,5 @@ mcp> quit
## Configuration ## Configuration
- `MCP_SERVER_URL` - Server URL (default: http://localhost:3001) - `MCP_SERVER_PORT` - Server URL (default: 8000)
- `MCP_TRANSPORT_TYPE` - Transport type: `streamable_http` (default) or `sse`

View File

@@ -18,6 +18,7 @@ from urllib.parse import parse_qs, urlparse
from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
@@ -149,8 +150,9 @@ class CallbackServer:
class SimpleAuthClient: class SimpleAuthClient:
"""Simple MCP client with auth support.""" """Simple MCP client with auth support."""
def __init__(self, server_url: str): def __init__(self, server_url: str, transport_type: str = "streamable_http"):
self.server_url = server_url self.server_url = server_url
self.transport_type = transport_type
self.session: ClientSession | None = None self.session: ClientSession | None = None
async def connect(self): async def connect(self):
@@ -195,17 +197,32 @@ class SimpleAuthClient:
callback_handler=callback_handler, callback_handler=callback_handler,
) )
# Create streamable HTTP transport with auth handler # Create transport with auth handler based on transport type
stream_context = streamablehttp_client( if self.transport_type == "sse":
print("📡 Opening SSE transport connection with auth...")
async with sse_client(
url=self.server_url,
auth=oauth_auth,
timeout=60,
) as (read_stream, write_stream):
await self._run_session(read_stream, write_stream, None)
else:
print("📡 Opening StreamableHTTP transport connection with auth...")
async with streamablehttp_client(
url=self.server_url, url=self.server_url,
auth=oauth_auth, auth=oauth_auth,
timeout=timedelta(seconds=60), timeout=timedelta(seconds=60),
) ) as (read_stream, write_stream, get_session_id):
await self._run_session(read_stream, write_stream, get_session_id)
print( except Exception as e:
"📡 Opening transport connection (HTTPX handles auth automatically)..." print(f"❌ Failed to connect: {e}")
) import traceback
async with stream_context as (read_stream, write_stream, get_session_id):
traceback.print_exc()
async def _run_session(self, read_stream, write_stream, get_session_id):
"""Run the MCP session with the given streams."""
print("🤝 Initializing MCP session...") print("🤝 Initializing MCP session...")
async with ClientSession(read_stream, write_stream) as session: async with ClientSession(read_stream, write_stream) as session:
self.session = session self.session = session
@@ -214,6 +231,7 @@ class SimpleAuthClient:
print("✨ Session initialization complete!") print("✨ Session initialization complete!")
print(f"\n✅ Connected to MCP server at {self.server_url}") print(f"\n✅ Connected to MCP server at {self.server_url}")
if get_session_id:
session_id = get_session_id() session_id = get_session_id()
if session_id: if session_id:
print(f"Session ID: {session_id}") print(f"Session ID: {session_id}")
@@ -221,12 +239,6 @@ class SimpleAuthClient:
# Run interactive loop # Run interactive loop
await self.interactive_loop() await self.interactive_loop()
except Exception as e:
print(f"❌ Failed to connect: {e}")
import traceback
traceback.print_exc()
async def list_tools(self): async def list_tools(self):
"""List available tools from the server.""" """List available tools from the server."""
if not self.session: if not self.session:
@@ -326,13 +338,20 @@ async def main():
"""Main entry point.""" """Main entry point."""
# Default server URL - can be overridden with environment variable # Default server URL - can be overridden with environment variable
# Most MCP streamable HTTP servers use /mcp as the endpoint # Most MCP streamable HTTP servers use /mcp as the endpoint
server_url = os.getenv("MCP_SERVER_URL", "http://localhost:8000/mcp") server_url = os.getenv("MCP_SERVER_PORT", 8000)
transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http")
server_url = (
f"http://localhost:{server_url}/mcp"
if transport_type == "streamable_http"
else f"http://localhost:{server_url}/sse"
)
print("🚀 Simple MCP Auth Client") print("🚀 Simple MCP Auth Client")
print(f"Connecting to: {server_url}") print(f"Connecting to: {server_url}")
print(f"Transport type: {transport_type}")
# Start connection flow - OAuth will be handled automatically # Start connection flow - OAuth will be handled automatically
client = SimpleAuthClient(server_url) client = SimpleAuthClient(server_url, transport_type)
await client.connect() await client.connect()

View File

@@ -26,12 +26,20 @@ async def sse_client(
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float = 5, timeout: float = 5,
sse_read_timeout: float = 60 * 5, sse_read_timeout: float = 60 * 5,
auth: httpx.Auth | None = None,
): ):
""" """
Client transport for SSE. Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new `sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`. event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
""" """
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
@@ -45,7 +53,7 @@ async def sse_client(
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
try: try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with create_mcp_http_client(headers=headers) as client: async with create_mcp_http_client(headers=headers, auth=auth) as client:
async with aconnect_sse( async with aconnect_sse(
client, client,
"GET", "GET",