mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 15:34:36 +01:00
492 lines
16 KiB
Python
492 lines
16 KiB
Python
"""
|
|
Base module interface and interceptor pattern implementation
|
|
"""
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Any, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
from fastapi import Request, Response
|
|
import json
|
|
import re
|
|
import copy
|
|
import time
|
|
import hashlib
|
|
from urllib.parse import urlparse
|
|
|
|
from app.core.logging import get_logger
|
|
from app.utils.exceptions import ValidationError, AuthenticationError, RateLimitExceeded
|
|
from app.services.permission_manager import permission_registry
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class Permission:
|
|
"""Represents a module permission"""
|
|
|
|
resource: str
|
|
action: str
|
|
description: str
|
|
|
|
def __str__(self) -> str:
|
|
return f"{self.resource}:{self.action}"
|
|
|
|
|
|
@dataclass
|
|
class ModuleMetrics:
|
|
"""Module performance metrics"""
|
|
|
|
requests_processed: int = 0
|
|
average_response_time: float = 0.0
|
|
error_rate: float = 0.0
|
|
last_activity: Optional[str] = None
|
|
total_errors: int = 0
|
|
uptime_start: float = 0.0
|
|
|
|
def __post_init__(self):
|
|
if self.uptime_start == 0.0:
|
|
self.uptime_start = time.time()
|
|
|
|
|
|
@dataclass
|
|
class ModuleHealth:
|
|
"""Module health status"""
|
|
|
|
status: str = "healthy" # healthy, warning, error
|
|
message: str = "Module is functioning normally"
|
|
uptime: float = 0.0
|
|
last_check: float = 0.0
|
|
|
|
def __post_init__(self):
|
|
if self.last_check == 0.0:
|
|
self.last_check = time.time()
|
|
|
|
|
|
class BaseModule(ABC):
|
|
"""Base class for all modules with interceptor pattern support"""
|
|
|
|
def __init__(self, module_id: str, config: Dict[str, Any] = None):
|
|
self.module_id = module_id
|
|
self.config = config or {}
|
|
self.metrics = ModuleMetrics()
|
|
self.health = ModuleHealth()
|
|
self.initialized = False
|
|
self.interceptors: List["ModuleInterceptor"] = []
|
|
|
|
# Register default interceptors
|
|
self._register_default_interceptors()
|
|
|
|
@abstractmethod
|
|
async def initialize(self) -> None:
|
|
"""Initialize the module"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def cleanup(self) -> None:
|
|
"""Cleanup module resources"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_required_permissions(self) -> List[Permission]:
|
|
"""Return list of permissions this module requires"""
|
|
return []
|
|
|
|
@abstractmethod
|
|
async def process_request(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
"""Process a module request"""
|
|
pass
|
|
|
|
def get_health(self) -> ModuleHealth:
|
|
"""Get current module health status"""
|
|
self.health.uptime = time.time() - self.metrics.uptime_start
|
|
self.health.last_check = time.time()
|
|
return self.health
|
|
|
|
def get_metrics(self) -> ModuleMetrics:
|
|
"""Get current module metrics"""
|
|
return self.metrics
|
|
|
|
def check_access(self, user_permissions: List[str], action: str) -> bool:
|
|
"""Check if user can perform action on this module"""
|
|
required = f"modules:{self.module_id}:{action}"
|
|
return permission_registry.check_permission(user_permissions, required)
|
|
|
|
def _register_default_interceptors(self):
|
|
"""Register default interceptors for all modules"""
|
|
self.interceptors = [
|
|
AuthenticationInterceptor(),
|
|
PermissionInterceptor(self),
|
|
ValidationInterceptor(),
|
|
MetricsInterceptor(self),
|
|
SecurityInterceptor(),
|
|
AuditInterceptor(self),
|
|
]
|
|
|
|
async def execute_with_interceptors(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
"""Execute request through interceptor chain"""
|
|
start_time = time.time()
|
|
|
|
try:
|
|
# Pre-processing interceptors
|
|
for interceptor in self.interceptors:
|
|
request, context = await interceptor.pre_process(request, context)
|
|
|
|
# Execute main module logic
|
|
response = await self.process_request(request, context)
|
|
|
|
# Post-processing interceptors (in reverse order)
|
|
for interceptor in reversed(self.interceptors):
|
|
response = await interceptor.post_process(request, context, response)
|
|
|
|
# Update metrics
|
|
self._update_metrics(start_time, success=True)
|
|
|
|
return response
|
|
|
|
except Exception as e:
|
|
# Update error metrics
|
|
self._update_metrics(start_time, success=False, error=str(e))
|
|
|
|
# Error handling interceptors
|
|
for interceptor in reversed(self.interceptors):
|
|
if hasattr(interceptor, "handle_error"):
|
|
await interceptor.handle_error(request, context, e)
|
|
|
|
raise
|
|
|
|
def _update_metrics(self, start_time: float, success: bool, error: str = None):
|
|
"""Update module metrics"""
|
|
duration = time.time() - start_time
|
|
|
|
self.metrics.requests_processed += 1
|
|
|
|
# Update average response time
|
|
if self.metrics.requests_processed == 1:
|
|
self.metrics.average_response_time = duration
|
|
else:
|
|
# Exponential moving average
|
|
alpha = 0.1
|
|
self.metrics.average_response_time = (
|
|
alpha * duration + (1 - alpha) * self.metrics.average_response_time
|
|
)
|
|
|
|
if not success:
|
|
self.metrics.total_errors += 1
|
|
self.metrics.error_rate = (
|
|
self.metrics.total_errors / self.metrics.requests_processed
|
|
)
|
|
|
|
# Update health status based on error rate
|
|
if self.metrics.error_rate > 0.1: # More than 10% error rate
|
|
self.health.status = "error"
|
|
self.health.message = f"High error rate: {self.metrics.error_rate:.2%}"
|
|
elif self.metrics.error_rate > 0.05: # More than 5% error rate
|
|
self.health.status = "warning"
|
|
self.health.message = (
|
|
f"Elevated error rate: {self.metrics.error_rate:.2%}"
|
|
)
|
|
else:
|
|
self.metrics.error_rate = (
|
|
self.metrics.total_errors / self.metrics.requests_processed
|
|
)
|
|
if self.metrics.error_rate <= 0.05:
|
|
self.health.status = "healthy"
|
|
self.health.message = "Module is functioning normally"
|
|
|
|
self.metrics.last_activity = time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
|
class ModuleInterceptor(ABC):
|
|
"""Base class for module interceptors"""
|
|
|
|
@abstractmethod
|
|
async def pre_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
"""Pre-process the request"""
|
|
return request, context
|
|
|
|
@abstractmethod
|
|
async def post_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
"""Post-process the response"""
|
|
return response
|
|
|
|
|
|
class AuthenticationInterceptor(ModuleInterceptor):
|
|
"""Handles authentication for module requests"""
|
|
|
|
async def pre_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
# Check if user is authenticated (context should contain user info from API auth)
|
|
if not context.get("user_id") and not context.get("api_key_id"):
|
|
raise AuthenticationError("Authentication required for module access")
|
|
|
|
return request, context
|
|
|
|
async def post_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
return response
|
|
|
|
|
|
class PermissionInterceptor(ModuleInterceptor):
|
|
"""Handles permission checking for module requests"""
|
|
|
|
def __init__(self, module: BaseModule):
|
|
self.module = module
|
|
|
|
async def pre_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
action = request.get("action", "execute")
|
|
user_permissions = context.get("user_permissions", [])
|
|
|
|
if not self.module.check_access(user_permissions, action):
|
|
raise AuthenticationError(
|
|
f"Insufficient permissions for module action: {action}"
|
|
)
|
|
|
|
return request, context
|
|
|
|
async def post_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
return response
|
|
|
|
|
|
class ValidationInterceptor(ModuleInterceptor):
|
|
"""Handles request validation and sanitization"""
|
|
|
|
async def pre_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
# Sanitize request data
|
|
sanitized_request = self._sanitize_request(request)
|
|
|
|
# Validate request structure
|
|
self._validate_request(sanitized_request)
|
|
|
|
return sanitized_request, context
|
|
|
|
async def post_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
# Sanitize response data
|
|
return self._sanitize_response(response)
|
|
|
|
def _sanitize_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Remove potentially dangerous content from request"""
|
|
sanitized = copy.deepcopy(request)
|
|
|
|
# Define dangerous patterns
|
|
dangerous_patterns = [
|
|
r"<script[^>]*>.*?</script>",
|
|
r"javascript:",
|
|
r"data:text/html",
|
|
r"vbscript:",
|
|
r"onload\s*=",
|
|
r"onerror\s*=",
|
|
r"eval\s*\(",
|
|
r"Function\s*\(",
|
|
]
|
|
|
|
def sanitize_value(value):
|
|
if isinstance(value, str):
|
|
# Remove dangerous patterns
|
|
for pattern in dangerous_patterns:
|
|
value = re.sub(pattern, "", value, flags=re.IGNORECASE)
|
|
|
|
# Limit string length
|
|
max_length = 10000
|
|
if len(value) > max_length:
|
|
value = value[:max_length]
|
|
logger.warning(
|
|
f"Truncated long string in request: {len(value)} chars"
|
|
)
|
|
|
|
return value
|
|
elif isinstance(value, dict):
|
|
return {k: sanitize_value(v) for k, v in value.items()}
|
|
elif isinstance(value, list):
|
|
return [sanitize_value(item) for item in value]
|
|
else:
|
|
return value
|
|
|
|
return sanitize_value(sanitized)
|
|
|
|
def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Sanitize response data"""
|
|
# Similar sanitization for responses
|
|
return self._sanitize_request(response)
|
|
|
|
def _validate_request(self, request: Dict[str, Any]):
|
|
"""Validate request structure"""
|
|
# Check for required fields
|
|
if not isinstance(request, dict):
|
|
raise ValidationError("Request must be a dictionary")
|
|
|
|
# Check request size
|
|
request_str = json.dumps(request)
|
|
max_size = 10 * 1024 * 1024 # 10MB
|
|
if len(request_str.encode()) > max_size:
|
|
raise ValidationError(
|
|
f"Request size exceeds maximum allowed ({max_size} bytes)"
|
|
)
|
|
|
|
|
|
class MetricsInterceptor(ModuleInterceptor):
|
|
"""Handles metrics collection for module requests"""
|
|
|
|
def __init__(self, module: BaseModule):
|
|
self.module = module
|
|
|
|
async def pre_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
context["_metrics_start_time"] = time.time()
|
|
return request, context
|
|
|
|
async def post_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
# Metrics are updated in the base module execute_with_interceptors method
|
|
return response
|
|
|
|
|
|
class SecurityInterceptor(ModuleInterceptor):
|
|
"""Handles security-related processing"""
|
|
|
|
async def pre_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
# Add security headers to context
|
|
context["security_headers"] = {
|
|
"X-Content-Type-Options": "nosniff",
|
|
"X-Frame-Options": "DENY",
|
|
"X-XSS-Protection": "1; mode=block",
|
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
|
|
}
|
|
|
|
# Check for suspicious patterns
|
|
self._check_security_patterns(request)
|
|
|
|
return request, context
|
|
|
|
async def post_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
# Remove any sensitive information from response
|
|
return self._remove_sensitive_data(response)
|
|
|
|
def _check_security_patterns(self, request: Dict[str, Any]):
|
|
"""Check for suspicious security patterns"""
|
|
request_str = json.dumps(request).lower()
|
|
|
|
suspicious_patterns = [
|
|
"union select",
|
|
"drop table",
|
|
"insert into",
|
|
"delete from",
|
|
"script>",
|
|
"javascript:",
|
|
"eval(",
|
|
"expression(",
|
|
"../",
|
|
"..\\",
|
|
"file://",
|
|
"ftp://",
|
|
]
|
|
|
|
for pattern in suspicious_patterns:
|
|
if pattern in request_str:
|
|
logger.warning(f"Suspicious pattern detected in request: {pattern}")
|
|
# Could implement additional security measures here
|
|
|
|
def _remove_sensitive_data(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Remove sensitive data from response"""
|
|
sensitive_keys = ["password", "secret", "token", "key", "private"]
|
|
|
|
def clean_dict(obj):
|
|
if isinstance(obj, dict):
|
|
return {
|
|
k: "***REDACTED***"
|
|
if any(sk in k.lower() for sk in sensitive_keys)
|
|
else clean_dict(v)
|
|
for k, v in obj.items()
|
|
}
|
|
elif isinstance(obj, list):
|
|
return [clean_dict(item) for item in obj]
|
|
else:
|
|
return obj
|
|
|
|
return clean_dict(response)
|
|
|
|
|
|
class AuditInterceptor(ModuleInterceptor):
|
|
"""Handles audit logging for module requests"""
|
|
|
|
def __init__(self, module: BaseModule):
|
|
self.module = module
|
|
|
|
async def pre_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any]
|
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
context["_audit_start_time"] = time.time()
|
|
context["_audit_request_hash"] = self._hash_request(request)
|
|
return request, context
|
|
|
|
async def post_process(
|
|
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
await self._log_audit_event(request, context, response, success=True)
|
|
return response
|
|
|
|
async def handle_error(
|
|
self, request: Dict[str, Any], context: Dict[str, Any], error: Exception
|
|
):
|
|
"""Handle error logging"""
|
|
await self._log_audit_event(
|
|
request, context, {"error": str(error)}, success=False
|
|
)
|
|
|
|
def _hash_request(self, request: Dict[str, Any]) -> str:
|
|
"""Create a hash of the request for audit purposes"""
|
|
request_str = json.dumps(request, sort_keys=True)
|
|
return hashlib.sha256(request_str.encode()).hexdigest()[:16]
|
|
|
|
async def _log_audit_event(
|
|
self,
|
|
request: Dict[str, Any],
|
|
context: Dict[str, Any],
|
|
response: Dict[str, Any],
|
|
success: bool,
|
|
):
|
|
"""Log audit event"""
|
|
duration = time.time() - context.get("_audit_start_time", time.time())
|
|
|
|
audit_data = {
|
|
"module_id": self.module.module_id,
|
|
"action": request.get("action", "execute"),
|
|
"user_id": context.get("user_id"),
|
|
"api_key_id": context.get("api_key_id"),
|
|
"ip_address": context.get("ip_address"),
|
|
"request_hash": context.get("_audit_request_hash"),
|
|
"success": success,
|
|
"duration_ms": int(duration * 1000),
|
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
}
|
|
|
|
if not success:
|
|
audit_data["error"] = response.get("error", "Unknown error")
|
|
|
|
# Log the audit event
|
|
logger.info(f"Module audit: {json.dumps(audit_data)}")
|
|
|
|
# Could also store in database for persistent audit trail
|