mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
add auth client sse (#760)
This commit is contained in:
@@ -18,6 +18,7 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
|
||||
@@ -149,8 +150,9 @@ class CallbackServer:
|
||||
class SimpleAuthClient:
|
||||
"""Simple MCP client with auth support."""
|
||||
|
||||
def __init__(self, server_url: str):
|
||||
def __init__(self, server_url: str, transport_type: str = "streamable_http"):
|
||||
self.server_url = server_url
|
||||
self.transport_type = transport_type
|
||||
self.session: ClientSession | None = None
|
||||
|
||||
async def connect(self):
|
||||
@@ -195,31 +197,23 @@ class SimpleAuthClient:
|
||||
callback_handler=callback_handler,
|
||||
)
|
||||
|
||||
# Create streamable HTTP transport with auth handler
|
||||
stream_context = streamablehttp_client(
|
||||
url=self.server_url,
|
||||
auth=oauth_auth,
|
||||
timeout=timedelta(seconds=60),
|
||||
)
|
||||
|
||||
print(
|
||||
"📡 Opening transport connection (HTTPX handles auth automatically)..."
|
||||
)
|
||||
async with stream_context as (read_stream, write_stream, get_session_id):
|
||||
print("🤝 Initializing MCP session...")
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
self.session = session
|
||||
print("⚡ Starting session initialization...")
|
||||
await session.initialize()
|
||||
print("✨ Session initialization complete!")
|
||||
|
||||
print(f"\n✅ Connected to MCP server at {self.server_url}")
|
||||
session_id = get_session_id()
|
||||
if session_id:
|
||||
print(f"Session ID: {session_id}")
|
||||
|
||||
# Run interactive loop
|
||||
await self.interactive_loop()
|
||||
# Create transport with auth handler based on transport type
|
||||
if self.transport_type == "sse":
|
||||
print("📡 Opening SSE transport connection with auth...")
|
||||
async with sse_client(
|
||||
url=self.server_url,
|
||||
auth=oauth_auth,
|
||||
timeout=60,
|
||||
) as (read_stream, write_stream):
|
||||
await self._run_session(read_stream, write_stream, None)
|
||||
else:
|
||||
print("📡 Opening StreamableHTTP transport connection with auth...")
|
||||
async with streamablehttp_client(
|
||||
url=self.server_url,
|
||||
auth=oauth_auth,
|
||||
timeout=timedelta(seconds=60),
|
||||
) as (read_stream, write_stream, get_session_id):
|
||||
await self._run_session(read_stream, write_stream, get_session_id)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to connect: {e}")
|
||||
@@ -227,6 +221,24 @@ class SimpleAuthClient:
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _run_session(self, read_stream, write_stream, get_session_id):
|
||||
"""Run the MCP session with the given streams."""
|
||||
print("🤝 Initializing MCP session...")
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
self.session = session
|
||||
print("⚡ Starting session initialization...")
|
||||
await session.initialize()
|
||||
print("✨ Session initialization complete!")
|
||||
|
||||
print(f"\n✅ Connected to MCP server at {self.server_url}")
|
||||
if get_session_id:
|
||||
session_id = get_session_id()
|
||||
if session_id:
|
||||
print(f"Session ID: {session_id}")
|
||||
|
||||
# Run interactive loop
|
||||
await self.interactive_loop()
|
||||
|
||||
async def list_tools(self):
|
||||
"""List available tools from the server."""
|
||||
if not self.session:
|
||||
@@ -326,13 +338,20 @@ async def main():
|
||||
"""Main entry point."""
|
||||
# Default server URL - can be overridden with environment variable
|
||||
# Most MCP streamable HTTP servers use /mcp as the endpoint
|
||||
server_url = os.getenv("MCP_SERVER_URL", "http://localhost:8000/mcp")
|
||||
server_url = os.getenv("MCP_SERVER_PORT", 8000)
|
||||
transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http")
|
||||
server_url = (
|
||||
f"http://localhost:{server_url}/mcp"
|
||||
if transport_type == "streamable_http"
|
||||
else f"http://localhost:{server_url}/sse"
|
||||
)
|
||||
|
||||
print("🚀 Simple MCP Auth Client")
|
||||
print(f"Connecting to: {server_url}")
|
||||
print(f"Transport type: {transport_type}")
|
||||
|
||||
# Start connection flow - OAuth will be handled automatically
|
||||
client = SimpleAuthClient(server_url)
|
||||
client = SimpleAuthClient(server_url, transport_type)
|
||||
await client.connect()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user