mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
572 lines
21 KiB
Python
572 lines
21 KiB
Python
"""
|
|
Plugin API Gateway
|
|
Handles authentication, routing, and security for plugin APIs
|
|
"""
|
|
import asyncio
|
|
import time
|
|
from jose import jwt
|
|
from typing import Dict, Any, List, Optional, Tuple
|
|
from fastapi import FastAPI, Request, Response, HTTPException, Depends
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from fastapi.middleware.base import BaseHTTPMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
import aiohttp
|
|
from urllib.parse import urlparse
|
|
import hashlib
|
|
import hmac
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import get_logger
|
|
from app.core.security import verify_jwt_token, get_current_user
|
|
from app.models.plugin import Plugin, PluginConfiguration, PluginAuditLog
|
|
from app.models.api_key import APIKey
|
|
from app.models.user import User
|
|
from app.db.database import get_db
|
|
from app.services.plugin_sandbox import plugin_loader
|
|
from app.services.plugin_context_manager import plugin_context_manager
|
|
from app.utils.exceptions import SecurityError, PluginError
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
logger = get_logger("plugin.gateway")
|
|
security = HTTPBearer()
|
|
|
|
|
|
class PluginAuthenticationService:
|
|
"""Handles plugin authentication and authorization"""
|
|
|
|
@staticmethod
|
|
async def verify_plugin_token(token: str, db: Session) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
|
|
"""Verify plugin authentication token"""
|
|
try:
|
|
# Decode JWT token
|
|
payload = jwt.decode(
|
|
token,
|
|
settings.JWT_SECRET,
|
|
algorithms=["HS256"]
|
|
)
|
|
|
|
user_id = payload.get("sub")
|
|
plugin_id = payload.get("plugin_id")
|
|
|
|
if not user_id or not plugin_id:
|
|
return None, None
|
|
|
|
# Verify plugin exists and is active
|
|
plugin = db.query(Plugin).filter(
|
|
Plugin.id == plugin_id,
|
|
Plugin.status == "enabled"
|
|
).first()
|
|
|
|
if not plugin:
|
|
return None, None
|
|
|
|
# Get user
|
|
user = db.query(User).filter(User.id == user_id).first()
|
|
if not user:
|
|
return None, None
|
|
|
|
return user_id, {
|
|
"user": user,
|
|
"plugin": plugin,
|
|
"permissions": payload.get("permissions", [])
|
|
}
|
|
|
|
except jwt.InvalidTokenError:
|
|
return None, None
|
|
|
|
@staticmethod
|
|
async def verify_api_key_access(api_key: str, plugin_id: str, db: Session) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
|
|
"""Verify API key has access to plugin"""
|
|
try:
|
|
# Find API key
|
|
api_key_obj = db.query(APIKey).filter(
|
|
APIKey.key_hash == hashlib.sha256(api_key.encode()).hexdigest(),
|
|
APIKey.is_active == True
|
|
).first()
|
|
|
|
if not api_key_obj:
|
|
return None, None
|
|
|
|
# Check if API key has plugin access
|
|
plugin = db.query(Plugin).filter(
|
|
Plugin.id == plugin_id,
|
|
Plugin.status == "enabled"
|
|
).first()
|
|
|
|
if not plugin:
|
|
return None, None
|
|
|
|
# Verify plugin permissions for API key
|
|
# TODO: Check plugin-specific permissions in API key scopes
|
|
|
|
return api_key_obj.user_id, {
|
|
"user": api_key_obj.user,
|
|
"plugin": plugin,
|
|
"api_key": api_key_obj,
|
|
"permissions": ["api_access"]
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"API key verification failed: {e}")
|
|
return None, None
|
|
|
|
@staticmethod
|
|
async def check_plugin_permissions(user_id: str, plugin_id: str, endpoint: str,
|
|
method: str, db: Session) -> bool:
|
|
"""Check if user has permission to access plugin endpoint"""
|
|
try:
|
|
# Get plugin configuration for user
|
|
config = db.query(PluginConfiguration).filter(
|
|
PluginConfiguration.user_id == user_id,
|
|
PluginConfiguration.plugin_id == plugin_id,
|
|
PluginConfiguration.is_active == True
|
|
).first()
|
|
|
|
if not config:
|
|
return False
|
|
|
|
# Get plugin manifest to check endpoint permissions
|
|
plugin = db.query(Plugin).filter(Plugin.id == plugin_id).first()
|
|
if not plugin or not plugin.manifest_data:
|
|
return False
|
|
|
|
# Check endpoint permissions in manifest
|
|
manifest = plugin.manifest_data
|
|
api_endpoints = manifest.get("spec", {}).get("api_endpoints", [])
|
|
|
|
for ep in api_endpoints:
|
|
if ep.get("path") == endpoint and method in ep.get("methods", []):
|
|
return ep.get("auth_required", True) # Default to requiring auth
|
|
|
|
# If endpoint not found in manifest, deny access
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Permission check failed: {e}")
|
|
return False
|
|
|
|
|
|
class PluginAuditService:
|
|
"""Handles audit logging for plugin access"""
|
|
|
|
@staticmethod
|
|
async def log_plugin_access(
|
|
plugin_id: str,
|
|
user_id: Optional[str],
|
|
api_key_id: Optional[str],
|
|
endpoint: str,
|
|
method: str,
|
|
ip_address: str,
|
|
user_agent: str,
|
|
status_code: int,
|
|
response_time_ms: float,
|
|
db: Session
|
|
):
|
|
"""Log plugin API access for audit trail"""
|
|
try:
|
|
audit_log = PluginAuditLog(
|
|
plugin_id=plugin_id,
|
|
user_id=user_id,
|
|
api_key_id=api_key_id,
|
|
action="api_access",
|
|
endpoint=endpoint,
|
|
method=method,
|
|
details={
|
|
"status_code": status_code,
|
|
"response_time_ms": response_time_ms,
|
|
"ip_address": ip_address,
|
|
"user_agent": user_agent
|
|
}
|
|
)
|
|
|
|
db.add(audit_log)
|
|
db.commit()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to log plugin access: {e}")
|
|
db.rollback()
|
|
|
|
|
|
class PluginRateLimiter:
|
|
"""Rate limiting for plugin API access"""
|
|
|
|
def __init__(self):
|
|
self.user_requests: Dict[str, List[float]] = {}
|
|
self.plugin_requests: Dict[str, List[float]] = {}
|
|
|
|
def check_rate_limit(self, user_id: str, plugin_id: str, limits: Dict[str, int]) -> bool:
|
|
"""Check if request is within rate limits"""
|
|
current_time = time.time()
|
|
window_seconds = 60 # 1 minute window
|
|
|
|
# Clean old requests
|
|
self._clean_old_requests(current_time, window_seconds)
|
|
|
|
# Check user rate limit
|
|
user_key = f"user:{user_id}"
|
|
user_limit = limits.get("user_requests_per_minute", 100)
|
|
|
|
if user_key not in self.user_requests:
|
|
self.user_requests[user_key] = []
|
|
|
|
if len(self.user_requests[user_key]) >= user_limit:
|
|
logger.warning(f"User {user_id} rate limit exceeded: {len(self.user_requests[user_key])}/{user_limit}")
|
|
return False
|
|
|
|
# Check plugin rate limit
|
|
plugin_key = f"plugin:{plugin_id}"
|
|
plugin_limit = limits.get("plugin_requests_per_minute", 200)
|
|
|
|
if plugin_key not in self.plugin_requests:
|
|
self.plugin_requests[plugin_key] = []
|
|
|
|
if len(self.plugin_requests[plugin_key]) >= plugin_limit:
|
|
logger.warning(f"Plugin {plugin_id} rate limit exceeded: {len(self.plugin_requests[plugin_key])}/{plugin_limit}")
|
|
return False
|
|
|
|
# Record requests
|
|
self.user_requests[user_key].append(current_time)
|
|
self.plugin_requests[plugin_key].append(current_time)
|
|
|
|
return True
|
|
|
|
def _clean_old_requests(self, current_time: float, window_seconds: int):
|
|
"""Remove requests older than window"""
|
|
cutoff_time = current_time - window_seconds
|
|
|
|
for key in self.user_requests:
|
|
self.user_requests[key] = [
|
|
req_time for req_time in self.user_requests[key]
|
|
if req_time > cutoff_time
|
|
]
|
|
|
|
for key in self.plugin_requests:
|
|
self.plugin_requests[key] = [
|
|
req_time for req_time in self.plugin_requests[key]
|
|
if req_time > cutoff_time
|
|
]
|
|
|
|
|
|
class PluginGatewayMiddleware(BaseHTTPMiddleware):
|
|
"""Middleware for plugin API gateway"""
|
|
|
|
def __init__(self, app: FastAPI):
|
|
super().__init__(app)
|
|
self.auth_service = PluginAuthenticationService()
|
|
self.audit_service = PluginAuditService()
|
|
self.rate_limiter = PluginRateLimiter()
|
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
"""Process plugin API requests"""
|
|
start_time = time.time()
|
|
|
|
# Check if this is a plugin API request
|
|
if not request.url.path.startswith("/api/v1/plugins/"):
|
|
return await call_next(request)
|
|
|
|
# Extract plugin ID from path
|
|
path_parts = request.url.path.split("/")
|
|
if len(path_parts) < 5:
|
|
return JSONResponse(
|
|
status_code=404,
|
|
content={"error": "Invalid plugin API path"}
|
|
)
|
|
|
|
plugin_id = path_parts[4]
|
|
plugin_endpoint = "/" + "/".join(path_parts[5:]) if len(path_parts) > 5 else "/"
|
|
|
|
# Get database session
|
|
db = next(get_db())
|
|
|
|
try:
|
|
# Authenticate request
|
|
auth_result = await self._authenticate_request(request, plugin_id, db)
|
|
if not auth_result:
|
|
return JSONResponse(
|
|
status_code=401,
|
|
content={"error": "Authentication failed"}
|
|
)
|
|
|
|
user_id, auth_context = auth_result
|
|
|
|
# Check permissions
|
|
has_permission = await self.auth_service.check_plugin_permissions(
|
|
user_id, plugin_id, plugin_endpoint, request.method, db
|
|
)
|
|
|
|
if not has_permission:
|
|
return JSONResponse(
|
|
status_code=403,
|
|
content={"error": "Insufficient permissions"}
|
|
)
|
|
|
|
# Check rate limits
|
|
rate_limits = {
|
|
"user_requests_per_minute": 100,
|
|
"plugin_requests_per_minute": 200
|
|
}
|
|
|
|
if not self.rate_limiter.check_rate_limit(user_id, plugin_id, rate_limits):
|
|
return JSONResponse(
|
|
status_code=429,
|
|
content={"error": "Rate limit exceeded"}
|
|
)
|
|
|
|
# Add authentication context to request
|
|
request.state.user_id = user_id
|
|
request.state.plugin_id = plugin_id
|
|
request.state.auth_context = auth_context
|
|
request.state.plugin_endpoint = plugin_endpoint
|
|
|
|
# Forward to plugin
|
|
response = await self._forward_to_plugin(request, plugin_id, plugin_endpoint, call_next)
|
|
|
|
# Log access
|
|
response_time = (time.time() - start_time) * 1000
|
|
await self.audit_service.log_plugin_access(
|
|
plugin_id=plugin_id,
|
|
user_id=user_id,
|
|
api_key_id=auth_context.get("api_key", {}).get("id") if "api_key" in auth_context else None,
|
|
endpoint=plugin_endpoint,
|
|
method=request.method,
|
|
ip_address=request.client.host,
|
|
user_agent=request.headers.get("user-agent", ""),
|
|
status_code=response.status_code,
|
|
response_time_ms=response_time,
|
|
db=db
|
|
)
|
|
|
|
return response
|
|
|
|
except Exception as e:
|
|
logger.error(f"Plugin gateway error: {e}")
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={"error": "Internal gateway error"}
|
|
)
|
|
finally:
|
|
db.close()
|
|
|
|
async def _authenticate_request(self, request: Request, plugin_id: str, db: Session) -> Optional[Tuple[str, Dict[str, Any]]]:
|
|
"""Authenticate plugin API request"""
|
|
|
|
# Check for Authorization header
|
|
auth_header = request.headers.get("authorization")
|
|
if not auth_header:
|
|
return None
|
|
|
|
if auth_header.startswith("Bearer "):
|
|
token = auth_header[7:]
|
|
|
|
# Try JWT token first
|
|
result = await self.auth_service.verify_plugin_token(token, db)
|
|
if result[0]:
|
|
return result
|
|
|
|
# Try API key
|
|
result = await self.auth_service.verify_api_key_access(token, plugin_id, db)
|
|
if result[0]:
|
|
return result
|
|
|
|
return None
|
|
|
|
async def _forward_to_plugin(self, request: Request, plugin_id: str,
|
|
plugin_endpoint: str, call_next) -> Response:
|
|
"""Forward request to plugin and handle response"""
|
|
|
|
# Check if plugin is loaded
|
|
plugin_instance = plugin_loader.loaded_plugins.get(plugin_id)
|
|
if not plugin_instance:
|
|
return JSONResponse(
|
|
status_code=503,
|
|
content={"error": f"Plugin {plugin_id} not loaded"}
|
|
)
|
|
|
|
# Check sandbox resource usage
|
|
sandbox = plugin_loader.get_plugin_sandbox(plugin_id)
|
|
if sandbox:
|
|
try:
|
|
sandbox.check_resource_usage()
|
|
sandbox.track_api_call()
|
|
except Exception as e:
|
|
return JSONResponse(
|
|
status_code=503,
|
|
content={"error": f"Plugin resource limit exceeded: {e}"}
|
|
)
|
|
|
|
# Add plugin context headers
|
|
request.headers.__dict__["_list"].extend([
|
|
(b"x-user-id", request.state.user_id.encode()),
|
|
(b"x-plugin-id", plugin_id.encode()),
|
|
(b"x-plugin-endpoint", plugin_endpoint.encode()),
|
|
(b"x-real-ip", request.client.host.encode()),
|
|
])
|
|
|
|
# Forward to plugin
|
|
return await call_next(request)
|
|
|
|
|
|
class PluginAPIGateway:
|
|
"""Main plugin API gateway service"""
|
|
|
|
def __init__(self):
|
|
self.middleware = None
|
|
self.app = None
|
|
|
|
def init_app(self, app: FastAPI):
|
|
"""Initialize gateway with FastAPI app"""
|
|
self.app = app
|
|
self.middleware = PluginGatewayMiddleware(app)
|
|
app.add_middleware(PluginGatewayMiddleware)
|
|
|
|
# Add plugin management endpoints
|
|
self._add_management_endpoints(app)
|
|
|
|
def _add_management_endpoints(self, app: FastAPI):
|
|
"""Add plugin management endpoints"""
|
|
|
|
@app.get("/api/v1/plugins")
|
|
async def list_plugins(db: Session = Depends(get_db)):
|
|
"""List available plugins"""
|
|
plugins = db.query(Plugin).filter(Plugin.status == "enabled").all()
|
|
|
|
plugin_list = []
|
|
for plugin in plugins:
|
|
# Get runtime status
|
|
plugin_instance = plugin_loader.loaded_plugins.get(plugin.id)
|
|
loaded = plugin_instance is not None
|
|
|
|
# Get resource stats if loaded
|
|
resource_stats = {}
|
|
if loaded:
|
|
resource_stats = plugin_loader.get_resource_stats(plugin.id)
|
|
|
|
plugin_list.append({
|
|
"id": plugin.id,
|
|
"name": plugin.name,
|
|
"version": plugin.version,
|
|
"description": plugin.description,
|
|
"status": plugin.status,
|
|
"loaded": loaded,
|
|
"resource_usage": resource_stats,
|
|
"created_at": plugin.created_at.isoformat(),
|
|
"updated_at": plugin.updated_at.isoformat()
|
|
})
|
|
|
|
return {"plugins": plugin_list}
|
|
|
|
@app.get("/api/v1/plugins/{plugin_id}")
|
|
async def get_plugin(plugin_id: str, db: Session = Depends(get_db)):
|
|
"""Get plugin details"""
|
|
plugin = db.query(Plugin).filter(Plugin.id == plugin_id).first()
|
|
if not plugin:
|
|
raise HTTPException(status_code=404, detail="Plugin not found")
|
|
|
|
# Get runtime status
|
|
plugin_instance = plugin_loader.loaded_plugins.get(plugin_id)
|
|
loaded = plugin_instance is not None
|
|
|
|
# Get resource stats if loaded
|
|
resource_stats = {}
|
|
health_status = {}
|
|
if loaded:
|
|
resource_stats = plugin_loader.get_resource_stats(plugin_id)
|
|
health_status = await plugin_instance.health_check()
|
|
|
|
return {
|
|
"plugin": {
|
|
"id": plugin.id,
|
|
"name": plugin.name,
|
|
"version": plugin.version,
|
|
"description": plugin.description,
|
|
"status": plugin.status,
|
|
"manifest": plugin.manifest_data,
|
|
"loaded": loaded,
|
|
"resource_usage": resource_stats,
|
|
"health": health_status,
|
|
"created_at": plugin.created_at.isoformat(),
|
|
"updated_at": plugin.updated_at.isoformat()
|
|
}
|
|
}
|
|
|
|
@app.post("/api/v1/plugins/{plugin_id}/load")
|
|
async def load_plugin(plugin_id: str, db: Session = Depends(get_db)):
|
|
"""Load a plugin"""
|
|
plugin = db.query(Plugin).filter(Plugin.id == plugin_id).first()
|
|
if not plugin:
|
|
raise HTTPException(status_code=404, detail="Plugin not found")
|
|
|
|
if plugin.status != "enabled":
|
|
raise HTTPException(status_code=400, detail="Plugin not enabled")
|
|
|
|
# Check if already loaded
|
|
if plugin_id in plugin_loader.loaded_plugins:
|
|
raise HTTPException(status_code=400, detail="Plugin already loaded")
|
|
|
|
try:
|
|
# Load plugin with proper context management
|
|
plugin_dir = f"/plugins/{plugin_id}"
|
|
|
|
# Create plugin context for standardized interface
|
|
plugin_context = plugin_context_manager.create_plugin_context(
|
|
plugin_id=plugin_id,
|
|
user_id="system", # System loading context
|
|
session_type="plugin_load"
|
|
)
|
|
|
|
# Generate plugin token based on context
|
|
plugin_token = plugin_context_manager.generate_plugin_token(plugin_context["context_id"])
|
|
|
|
# Log plugin loading action
|
|
plugin_context_manager.add_audit_trail_entry(
|
|
plugin_context["context_id"],
|
|
"plugin_load",
|
|
{"plugin_dir": plugin_dir, "action": "load_plugin_with_sandbox"}
|
|
)
|
|
|
|
await plugin_loader.load_plugin_with_sandbox(plugin_dir, plugin_token)
|
|
|
|
return {"status": "loaded", "plugin_id": plugin_id}
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Failed to load plugin: {e}")
|
|
|
|
@app.post("/api/v1/plugins/{plugin_id}/unload")
|
|
async def unload_plugin(plugin_id: str):
|
|
"""Unload a plugin"""
|
|
if plugin_id not in plugin_loader.loaded_plugins:
|
|
raise HTTPException(status_code=404, detail="Plugin not loaded")
|
|
|
|
try:
|
|
success = await plugin_loader.unload_plugin(plugin_id)
|
|
if success:
|
|
return {"status": "unloaded", "plugin_id": plugin_id}
|
|
else:
|
|
raise HTTPException(status_code=500, detail="Failed to unload plugin")
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Failed to unload plugin: {e}")
|
|
|
|
@app.get("/api/v1/plugins/{plugin_id}/health")
|
|
async def get_plugin_health(plugin_id: str):
|
|
"""Get plugin health status"""
|
|
plugin_instance = plugin_loader.loaded_plugins.get(plugin_id)
|
|
if not plugin_instance:
|
|
raise HTTPException(status_code=404, detail="Plugin not loaded")
|
|
|
|
try:
|
|
health = await plugin_instance.health_check()
|
|
resource_stats = plugin_loader.get_resource_stats(plugin_id)
|
|
|
|
return {
|
|
"health": health,
|
|
"resource_usage": resource_stats
|
|
}
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Health check failed: {e}")
|
|
|
|
|
|
# Global gateway instance
|
|
plugin_gateway = PluginAPIGateway() |