Files
mcp-python-sdk/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py
2025-05-19 20:38:04 +01:00

346 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Simple MCP client example with OAuth authentication support.
This client connects to an MCP server using streamable HTTP transport with OAuth.
"""
import asyncio
import os
import threading
import time
import webbrowser
from datetime import timedelta
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.parse import parse_qs, urlparse
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
class InMemoryTokenStorage(TokenStorage):
    """Token storage that keeps everything in instance attributes.

    Nothing is written anywhere else: the stored tokens and client
    registration exist only for as long as this object does.
    """

    def __init__(self) -> None:
        """Start with no client registration and no tokens."""
        self._client_info: OAuthClientInformationFull | None = None
        self._tokens: OAuthToken | None = None

    async def get_tokens(self) -> OAuthToken | None:
        """Return the tokens last passed to ``set_tokens``, or ``None``."""
        return self._tokens

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Replace the stored tokens with *tokens*."""
        self._tokens = tokens

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Return the client info last passed to ``set_client_info``, or ``None``."""
        return self._client_info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Replace the stored client registration with *client_info*."""
        self._client_info = client_info
class CallbackHandler(BaseHTTPRequestHandler):
    """HTTP request handler that records the outcome of an OAuth redirect.

    The result is written into the ``callback_data`` mapping supplied at
    construction time, under the keys ``"authorization_code"``, ``"state"``
    and ``"error"``.
    """

    def __init__(self, request, client_address, server, callback_data):
        """Store *callback_data*, then let the base class handle the request.

        The attribute is assigned before ``super().__init__`` because the
        base-class constructor processes the request (and therefore calls
        ``do_GET``) before it returns.
        """
        self.callback_data = callback_data
        super().__init__(request, client_address, server)

    def _send_html(self, status, body):
        """Send *body* (bytes) as a ``text/html`` response with *status*."""
        self.send_response(status)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Handle the GET request issued by the OAuth redirect.

        * ``code`` present  -> record the code and optional ``state``; 200.
        * ``error`` present -> record the error; 400 with the error shown.
        * neither           -> 404 with no body.
        """
        # Local import keeps this handler self-contained.
        from html import escape

        query_params = parse_qs(urlparse(self.path).query)

        if "code" in query_params:
            self.callback_data["authorization_code"] = query_params["code"][0]
            self.callback_data["state"] = query_params.get("state", [None])[0]
            self._send_html(
                200,
                b"""
<html>
<body>
<h1>Authorization Successful!</h1>
<p>You can close this window and return to the terminal.</p>
<script>setTimeout(() => window.close(), 2000);</script>
</body>
</html>
""",
            )
        elif "error" in query_params:
            error = query_params["error"][0]
            # Keep the raw value for the caller; escape only what is rendered.
            self.callback_data["error"] = error
            # The value comes straight from the request URL, so it must be
            # HTML-escaped before being echoed back into the page.
            page = f"""
<html>
<body>
<h1>Authorization Failed</h1>
<p>Error: {escape(error)}</p>
<p>You can close this window and return to the terminal.</p>
</body>
</html>
"""
            self._send_html(400, page.encode())
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        """Suppress the base class's per-request logging."""
class CallbackServer:
    """Local HTTP server that receives the OAuth redirect on ``localhost``.

    The handler it serves writes into ``callback_data``; callers poll that
    mapping through ``wait_for_callback`` and ``get_state``.
    """

    def __init__(self, port=3000):
        """Prepare a server for *port*; nothing is bound until ``start``."""
        self.port = port
        self.server = None
        self.thread = None
        self.callback_data = {"authorization_code": None, "state": None, "error": None}

    def _create_handler_with_data(self):
        """Return a ``CallbackHandler`` subclass bound to ``callback_data``.

        ``HTTPServer`` instantiates the handler with three arguments, so the
        extra ``callback_data`` argument is supplied through a closure.
        """
        shared_data = self.callback_data

        class DataCallbackHandler(CallbackHandler):
            def __init__(self, request, client_address, server):
                super().__init__(request, client_address, server, shared_data)

        return DataCallbackHandler

    def start(self):
        """Bind to ``localhost:port`` and serve from a daemon thread."""
        handler_class = self._create_handler_with_data()
        self.server = HTTPServer(("localhost", self.port), handler_class)
        self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self.thread.start()
        print(f"🖥️ Started callback server on http://localhost:{self.port}")

    def stop(self):
        """Shut the server down and release its socket.

        Safe to call more than once, and before ``start``: the references
        are cleared first, so a repeated call finds nothing left to stop.
        """
        server, thread = self.server, self.thread
        self.server = None
        self.thread = None
        if server is not None:
            server.shutdown()
            server.server_close()
        if thread is not None:
            thread.join(timeout=1)

    def wait_for_callback(self, timeout=300):
        """Block until the redirect arrives and return the authorization code.

        Polls ``callback_data`` every 0.1 s. The data is always checked at
        least once, so a result that is already present is returned even
        with ``timeout=0``.

        Raises:
            RuntimeError: the redirect carried an ``error`` parameter.
            TimeoutError: nothing arrived within *timeout* seconds.
        """
        # Monotonic clock: the deadline must not move if the wall clock is
        # adjusted while we wait.
        deadline = time.monotonic() + timeout
        while True:
            if self.callback_data["authorization_code"]:
                return self.callback_data["authorization_code"]
            if self.callback_data["error"]:
                raise RuntimeError(f"OAuth error: {self.callback_data['error']}")
            if time.monotonic() >= deadline:
                raise TimeoutError("Timeout waiting for OAuth callback")
            time.sleep(0.1)

    def get_state(self):
        """Return the ``state`` value received with the redirect, if any."""
        return self.callback_data["state"]
class SimpleAuthClient:
"""Simple MCP client with auth support."""
def __init__(self, server_url: str):
self.server_url = server_url
self.session: ClientSession | None = None
async def connect(self):
"""Connect to the MCP server."""
print(f"🔗 Attempting to connect to {self.server_url}...")
try:
# Set up callback server
callback_server = CallbackServer(port=3000)
callback_server.start()
async def callback_handler() -> tuple[str, str | None]:
"""Wait for OAuth callback and return auth code and state."""
print("⏳ Waiting for authorization callback...")
try:
auth_code = callback_server.wait_for_callback(timeout=300)
return auth_code, callback_server.get_state()
finally:
callback_server.stop()
client_metadata_dict = {
"client_name": "Simple Auth Client",
"redirect_uris": ["http://localhost:3000/callback"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
}
async def _default_redirect_handler(authorization_url: str) -> None:
"""Default redirect handler that opens the URL in a browser."""
print(f"Opening browser for authorization: {authorization_url}")
webbrowser.open(authorization_url)
# Create OAuth authentication handler using the new interface
oauth_auth = OAuthClientProvider(
server_url=self.server_url.replace("/mcp", ""),
client_metadata=OAuthClientMetadata.model_validate(
client_metadata_dict
),
storage=InMemoryTokenStorage(),
redirect_handler=_default_redirect_handler,
callback_handler=callback_handler,
)
# Create streamable HTTP transport with auth handler
stream_context = streamablehttp_client(
url=self.server_url,
auth=oauth_auth,
timeout=timedelta(seconds=60),
)
print(
"📡 Opening transport connection (HTTPX handles auth automatically)..."
)
async with stream_context as (read_stream, write_stream, get_session_id):
print("🤝 Initializing MCP session...")
async with ClientSession(read_stream, write_stream) as session:
self.session = session
print("⚡ Starting session initialization...")
await session.initialize()
print("✨ Session initialization complete!")
print(f"\n✅ Connected to MCP server at {self.server_url}")
session_id = get_session_id()
if session_id:
print(f"Session ID: {session_id}")
# Run interactive loop
await self.interactive_loop()
except Exception as e:
print(f"❌ Failed to connect: {e}")
import traceback
traceback.print_exc()
async def list_tools(self):
"""List available tools from the server."""
if not self.session:
print("❌ Not connected to server")
return
try:
result = await self.session.list_tools()
if hasattr(result, "tools") and result.tools:
print("\n📋 Available tools:")
for i, tool in enumerate(result.tools, 1):
print(f"{i}. {tool.name}")
if tool.description:
print(f" Description: {tool.description}")
print()
else:
print("No tools available")
except Exception as e:
print(f"❌ Failed to list tools: {e}")
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None):
"""Call a specific tool."""
if not self.session:
print("❌ Not connected to server")
return
try:
result = await self.session.call_tool(tool_name, arguments or {})
print(f"\n🔧 Tool '{tool_name}' result:")
if hasattr(result, "content"):
for content in result.content:
if content.type == "text":
print(content.text)
else:
print(content)
else:
print(result)
except Exception as e:
print(f"❌ Failed to call tool '{tool_name}': {e}")
async def interactive_loop(self):
"""Run interactive command loop."""
print("\n🎯 Interactive MCP Client")
print("Commands:")
print(" list - List available tools")
print(" call <tool_name> [args] - Call a tool")
print(" quit - Exit the client")
print()
while True:
try:
command = input("mcp> ").strip()
if not command:
continue
if command == "quit":
break
elif command == "list":
await self.list_tools()
elif command.startswith("call "):
parts = command.split(maxsplit=2)
tool_name = parts[1] if len(parts) > 1 else ""
if not tool_name:
print("❌ Please specify a tool name")
continue
# Parse arguments (simple JSON-like format)
arguments = {}
if len(parts) > 2:
import json
try:
arguments = json.loads(parts[2])
except json.JSONDecodeError:
print("❌ Invalid arguments format (expected JSON)")
continue
await self.call_tool(tool_name, arguments)
else:
print(
"❌ Unknown command. Try 'list', 'call <tool_name>', or 'quit'"
)
except KeyboardInterrupt:
print("\n\n👋 Goodbye!")
break
except EOFError:
break
async def main():
    """Resolve the server URL, announce it, and run the client."""
    # MCP_SERVER_URL overrides the default. Most MCP streamable HTTP
    # servers use /mcp as the endpoint, hence the default path.
    target_url = os.environ.get("MCP_SERVER_URL", "http://localhost:8000/mcp")

    print("🚀 Simple MCP Auth Client")
    print(f"Connecting to: {target_url}")

    # connect() drives the whole flow, OAuth included.
    await SimpleAuthClient(target_url).connect()
def cli():
    """Synchronous entry point for the uv script: run ``main`` to completion."""
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    cli()