mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Add OAuth authentication client for HTTPX (#751)
Co-authored-by: Paul Carleton <paulc@anthropic.com>
This commit is contained in:
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple MCP client example with OAuth authentication support.
|
||||
|
||||
This client connects to an MCP server using streamable HTTP transport with OAuth.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import webbrowser
|
||||
from datetime import timedelta
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
|
||||
|
||||
class InMemoryTokenStorage(TokenStorage):
|
||||
"""Simple in-memory token storage implementation."""
|
||||
|
||||
def __init__(self):
|
||||
self._tokens: OAuthToken | None = None
|
||||
self._client_info: OAuthClientInformationFull | None = None
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
return self._tokens
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
self._tokens = tokens
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
return self._client_info
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
self._client_info = client_info
|
||||
|
||||
|
||||
class CallbackHandler(BaseHTTPRequestHandler):
|
||||
"""Simple HTTP handler to capture OAuth callback."""
|
||||
|
||||
def __init__(self, request, client_address, server, callback_data):
|
||||
"""Initialize with callback data storage."""
|
||||
self.callback_data = callback_data
|
||||
super().__init__(request, client_address, server)
|
||||
|
||||
def do_GET(self):
|
||||
"""Handle GET request from OAuth redirect."""
|
||||
parsed = urlparse(self.path)
|
||||
query_params = parse_qs(parsed.query)
|
||||
|
||||
if "code" in query_params:
|
||||
self.callback_data["authorization_code"] = query_params["code"][0]
|
||||
self.callback_data["state"] = query_params.get("state", [None])[0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"""
|
||||
<html>
|
||||
<body>
|
||||
<h1>Authorization Successful!</h1>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
<script>setTimeout(() => window.close(), 2000);</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
elif "error" in query_params:
|
||||
self.callback_data["error"] = query_params["error"][0]
|
||||
self.send_response(400)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
f"""
|
||||
<html>
|
||||
<body>
|
||||
<h1>Authorization Failed</h1>
|
||||
<p>Error: {query_params['error'][0]}</p>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
</body>
|
||||
</html>
|
||||
""".encode()
|
||||
)
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
"""Suppress default logging."""
|
||||
pass
|
||||
|
||||
|
||||
class CallbackServer:
|
||||
"""Simple server to handle OAuth callbacks."""
|
||||
|
||||
def __init__(self, port=3000):
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.thread = None
|
||||
self.callback_data = {"authorization_code": None, "state": None, "error": None}
|
||||
|
||||
def _create_handler_with_data(self):
|
||||
"""Create a handler class with access to callback data."""
|
||||
callback_data = self.callback_data
|
||||
|
||||
class DataCallbackHandler(CallbackHandler):
|
||||
def __init__(self, request, client_address, server):
|
||||
super().__init__(request, client_address, server, callback_data)
|
||||
|
||||
return DataCallbackHandler
|
||||
|
||||
def start(self):
|
||||
"""Start the callback server in a background thread."""
|
||||
handler_class = self._create_handler_with_data()
|
||||
self.server = HTTPServer(("localhost", self.port), handler_class)
|
||||
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||
self.thread.start()
|
||||
print(f"🖥️ Started callback server on http://localhost:{self.port}")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the callback server."""
|
||||
if self.server:
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
if self.thread:
|
||||
self.thread.join(timeout=1)
|
||||
|
||||
def wait_for_callback(self, timeout=300):
|
||||
"""Wait for OAuth callback with timeout."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if self.callback_data["authorization_code"]:
|
||||
return self.callback_data["authorization_code"]
|
||||
elif self.callback_data["error"]:
|
||||
raise Exception(f"OAuth error: {self.callback_data['error']}")
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timeout waiting for OAuth callback")
|
||||
|
||||
def get_state(self):
|
||||
"""Get the received state parameter."""
|
||||
return self.callback_data["state"]
|
||||
|
||||
|
||||
class SimpleAuthClient:
|
||||
"""Simple MCP client with auth support."""
|
||||
|
||||
def __init__(self, server_url: str):
|
||||
self.server_url = server_url
|
||||
self.session: ClientSession | None = None
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the MCP server."""
|
||||
print(f"🔗 Attempting to connect to {self.server_url}...")
|
||||
|
||||
try:
|
||||
# Set up callback server
|
||||
callback_server = CallbackServer(port=3000)
|
||||
callback_server.start()
|
||||
|
||||
async def callback_handler() -> tuple[str, str | None]:
|
||||
"""Wait for OAuth callback and return auth code and state."""
|
||||
print("⏳ Waiting for authorization callback...")
|
||||
try:
|
||||
auth_code = callback_server.wait_for_callback(timeout=300)
|
||||
return auth_code, callback_server.get_state()
|
||||
finally:
|
||||
callback_server.stop()
|
||||
|
||||
client_metadata_dict = {
|
||||
"client_name": "Simple Auth Client",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
}
|
||||
|
||||
async def _default_redirect_handler(authorization_url: str) -> None:
|
||||
"""Default redirect handler that opens the URL in a browser."""
|
||||
print(f"Opening browser for authorization: {authorization_url}")
|
||||
webbrowser.open(authorization_url)
|
||||
|
||||
# Create OAuth authentication handler using the new interface
|
||||
oauth_auth = OAuthClientProvider(
|
||||
server_url=self.server_url.replace("/mcp", ""),
|
||||
client_metadata=OAuthClientMetadata.model_validate(
|
||||
client_metadata_dict
|
||||
),
|
||||
storage=InMemoryTokenStorage(),
|
||||
redirect_handler=_default_redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
)
|
||||
|
||||
# Create streamable HTTP transport with auth handler
|
||||
stream_context = streamablehttp_client(
|
||||
url=self.server_url,
|
||||
auth=oauth_auth,
|
||||
timeout=timedelta(seconds=60),
|
||||
)
|
||||
|
||||
print(
|
||||
"📡 Opening transport connection (HTTPX handles auth automatically)..."
|
||||
)
|
||||
async with stream_context as (read_stream, write_stream, get_session_id):
|
||||
print("🤝 Initializing MCP session...")
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
self.session = session
|
||||
print("⚡ Starting session initialization...")
|
||||
await session.initialize()
|
||||
print("✨ Session initialization complete!")
|
||||
|
||||
print(f"\n✅ Connected to MCP server at {self.server_url}")
|
||||
session_id = get_session_id()
|
||||
if session_id:
|
||||
print(f"Session ID: {session_id}")
|
||||
|
||||
# Run interactive loop
|
||||
await self.interactive_loop()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to connect: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def list_tools(self):
|
||||
"""List available tools from the server."""
|
||||
if not self.session:
|
||||
print("❌ Not connected to server")
|
||||
return
|
||||
|
||||
try:
|
||||
result = await self.session.list_tools()
|
||||
if hasattr(result, "tools") and result.tools:
|
||||
print("\n📋 Available tools:")
|
||||
for i, tool in enumerate(result.tools, 1):
|
||||
print(f"{i}. {tool.name}")
|
||||
if tool.description:
|
||||
print(f" Description: {tool.description}")
|
||||
print()
|
||||
else:
|
||||
print("No tools available")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to list tools: {e}")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None):
|
||||
"""Call a specific tool."""
|
||||
if not self.session:
|
||||
print("❌ Not connected to server")
|
||||
return
|
||||
|
||||
try:
|
||||
result = await self.session.call_tool(tool_name, arguments or {})
|
||||
print(f"\n🔧 Tool '{tool_name}' result:")
|
||||
if hasattr(result, "content"):
|
||||
for content in result.content:
|
||||
if content.type == "text":
|
||||
print(content.text)
|
||||
else:
|
||||
print(content)
|
||||
else:
|
||||
print(result)
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to call tool '{tool_name}': {e}")
|
||||
|
||||
async def interactive_loop(self):
|
||||
"""Run interactive command loop."""
|
||||
print("\n🎯 Interactive MCP Client")
|
||||
print("Commands:")
|
||||
print(" list - List available tools")
|
||||
print(" call <tool_name> [args] - Call a tool")
|
||||
print(" quit - Exit the client")
|
||||
print()
|
||||
|
||||
while True:
|
||||
try:
|
||||
command = input("mcp> ").strip()
|
||||
|
||||
if not command:
|
||||
continue
|
||||
|
||||
if command == "quit":
|
||||
break
|
||||
|
||||
elif command == "list":
|
||||
await self.list_tools()
|
||||
|
||||
elif command.startswith("call "):
|
||||
parts = command.split(maxsplit=2)
|
||||
tool_name = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
if not tool_name:
|
||||
print("❌ Please specify a tool name")
|
||||
continue
|
||||
|
||||
# Parse arguments (simple JSON-like format)
|
||||
arguments = {}
|
||||
if len(parts) > 2:
|
||||
import json
|
||||
|
||||
try:
|
||||
arguments = json.loads(parts[2])
|
||||
except json.JSONDecodeError:
|
||||
print("❌ Invalid arguments format (expected JSON)")
|
||||
continue
|
||||
|
||||
await self.call_tool(tool_name, arguments)
|
||||
|
||||
else:
|
||||
print(
|
||||
"❌ Unknown command. Try 'list', 'call <tool_name>', or 'quit'"
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n👋 Goodbye!")
|
||||
break
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
# Default server URL - can be overridden with environment variable
|
||||
# Most MCP streamable HTTP servers use /mcp as the endpoint
|
||||
server_url = os.getenv("MCP_SERVER_URL", "http://localhost:8000/mcp")
|
||||
|
||||
print("🚀 Simple MCP Auth Client")
|
||||
print(f"Connecting to: {server_url}")
|
||||
|
||||
# Start connection flow - OAuth will be handled automatically
|
||||
client = SimpleAuthClient(server_url)
|
||||
await client.connect()
|
||||
|
||||
|
||||
def cli():
|
||||
"""CLI entry point for uv script."""
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Reference in New Issue
Block a user