mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
603 lines
23 KiB
Python
603 lines
23 KiB
Python
"""
|
|
Base Plugin Class and Plugin Runtime Environment
|
|
Provides the foundation for all Enclava plugins with security and isolation
|
|
"""
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, List, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
from fastapi import APIRouter, Request, HTTPException, Depends
|
|
import asyncio
|
|
import aiohttp
|
|
import logging
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
import importlib.util
|
|
import sys
|
|
|
|
from app.schemas.plugin_manifest import PluginManifest, PluginManifestValidator
|
|
from app.core.logging import get_logger
|
|
from app.core.config import settings
|
|
from app.utils.exceptions import SecurityError, ValidationError
|
|
from app.models.plugin import PluginConfiguration
|
|
from app.models.user import User
|
|
from app.db.database import get_db
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
@dataclass
class PluginContext:
    """Plugin execution context with user and authentication info.

    All fields are optional and default to ``None``, except that
    ``user_permissions`` is normalised to an empty list when it is omitted or
    passed as ``None`` so consumers can always iterate it or test membership.
    """

    user_id: Optional[str] = None
    api_key_id: Optional[str] = None
    # Accepts None for backward compatibility; replaced by [] in __post_init__.
    user_permissions: Optional[List[str]] = None
    ip_address: Optional[str] = None
    user_agent: Optional[str] = None
    request_id: Optional[str] = None

    def __post_init__(self) -> None:
        """Replace a missing permission list with a fresh, per-instance list."""
        if self.user_permissions is None:
            self.user_permissions = []
|
|
|
|
|
|
class PlatformAPIClient:
    """Secure client for plugins to access platform APIs.

    Every request is sent to ``settings.INTERNAL_API_URL`` (falling back to
    ``http://localhost:58000``) and carries the plugin's bearer token plus
    headers identifying the calling plugin.
    """

    def __init__(self, plugin_id: str, plugin_token: str):
        """Store the plugin identity/token and resolve the platform base URL."""
        self.plugin_id = plugin_id
        self.plugin_token = plugin_token
        self.base_url = settings.INTERNAL_API_URL or "http://localhost:58000"
        self.logger = get_logger(f"plugin.{plugin_id}.api_client")

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
        """Make authenticated request to platform API.

        Args:
            method: HTTP method name (e.g. ``'GET'``).
            endpoint: Path appended verbatim to ``self.base_url``.
            **kwargs: Forwarded to ``aiohttp.ClientSession.request``. A
                ``headers`` entry is merged with (and overridden by) the
                authentication headers.

        Returns:
            The decoded JSON body when the response content type is
            ``application/json``; otherwise ``{"data": <response text>}``.

        Raises:
            HTTPException: With the upstream status code when the platform
                answers with status >= 400, or with 503 when the request
                fails with an ``aiohttp.ClientError``.
        """
        # Work on a copy so the caller's header mapping is never mutated and
        # the bearer token does not end up in a dict owned by the caller.
        headers = dict(kwargs.pop('headers', None) or {})
        headers.update({
            'Authorization': f'Bearer {self.plugin_token}',
            'X-Plugin-ID': self.plugin_id,
            'X-Platform-Client': 'plugin',
            'Content-Type': 'application/json'
        })

        url = f"{self.base_url}{endpoint}"

        try:
            async with aiohttp.ClientSession() as session:
                async with session.request(method, url, headers=headers, **kwargs) as response:
                    if response.status >= 400:
                        error_text = await response.text()
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"Platform API error: {error_text}"
                        )

                    if response.content_type == 'application/json':
                        return await response.json()
                    return {"data": await response.text()}

        except aiohttp.ClientError as e:
            self.logger.error(f"Platform API client error: {e}")
            raise HTTPException(
                status_code=503,
                detail=f"Platform API unavailable: {str(e)}"
            ) from e

    async def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
        """GET request to platform API"""
        return await self._make_request('GET', endpoint, **kwargs)

    async def post(self, endpoint: str, data: Optional[Dict[str, Any]] = None,
                   **kwargs) -> Dict[str, Any]:
        """POST request to platform API.

        ``data`` is sent as the JSON body whenever it is not ``None``, so an
        explicit empty dict is transmitted as ``{}`` rather than dropped.
        """
        if data is not None:
            kwargs['json'] = data
        return await self._make_request('POST', endpoint, **kwargs)

    async def put(self, endpoint: str, data: Optional[Dict[str, Any]] = None,
                  **kwargs) -> Dict[str, Any]:
        """PUT request to platform API.

        ``data`` is sent as the JSON body whenever it is not ``None``, so an
        explicit empty dict is transmitted as ``{}`` rather than dropped.
        """
        if data is not None:
            kwargs['json'] = data
        return await self._make_request('PUT', endpoint, **kwargs)

    async def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
        """DELETE request to platform API"""
        return await self._make_request('DELETE', endpoint, **kwargs)

    # Platform-specific API methods
    async def call_chatbot_api(self, chatbot_id: str, message: str,
                               context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Consume platform chatbot API.

        Posts ``message`` and ``context`` (``{}`` when omitted) to the
        external chat endpoint of ``chatbot_id``.
        """
        return await self.post(
            f"/api/v1/chatbot/external/{chatbot_id}/chat",
            {
                "message": message,
                "context": context or {}
            }
        )

    async def call_llm_api(self, model: str, messages: List[Dict[str, Any]],
                           **kwargs) -> Dict[str, Any]:
        """Consume platform LLM API.

        Extra keyword arguments are merged into the request payload next to
        ``model`` and ``messages``.
        """
        return await self.post(
            "/api/v1/llm/chat/completions",
            {
                "model": model,
                "messages": messages,
                **kwargs
            }
        )

    async def search_rag(self, collection: str, query: str,
                         top_k: int = 5) -> Dict[str, Any]:
        """Consume platform RAG API by searching ``collection`` for ``query``."""
        return await self.post(
            f"/api/v1/rag/collections/{collection}/search",
            {
                "query": query,
                "top_k": top_k
            }
        )

    async def get_embeddings(self, model: str, input_text: str) -> Dict[str, Any]:
        """Generate embeddings via platform API"""
        return await self.post(
            "/api/v1/llm/embeddings",
            {
                "model": model,
                "input": input_text
            }
        )
|
|
|
|
|
|
class PluginConfigManager:
    """Manages plugin configuration with validation and encryption"""

    def __init__(self, plugin_id: str):
        """Bind the manager to ``plugin_id`` and create its logger."""
        self.plugin_id = plugin_id
        self.logger = get_logger(f"plugin.{plugin_id}.config")

    async def get_config(self, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Get plugin configuration for user (or default).

        Args:
            user_id: When given, the active configuration owned by that user
                is looked up; otherwise the active configuration flagged
                ``is_default`` is used.

        Returns:
            The stored ``config_data`` of the first matching row, or ``{}``
            when no row matches, the stored data is empty, or any error
            occurs (errors are logged, not raised).
        """
        try:
            # Imported lazily, as in the rest of this class.
            from app.db.database import SessionLocal
            db = SessionLocal()

            try:
                # Query for active configuration
                # (`== True` is intentional: it builds a SQL expression.)
                query = db.query(PluginConfiguration).filter(
                    PluginConfiguration.plugin_id == self.plugin_id,
                    PluginConfiguration.is_active == True
                )

                if user_id:
                    # Get user-specific configuration
                    query = query.filter(PluginConfiguration.user_id == user_id)
                else:
                    # Get default configuration (is_default=True)
                    query = query.filter(PluginConfiguration.is_default == True)

                config = query.first()

                if config:
                    self.logger.debug(f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}")
                    return config.config_data or {}

                self.logger.debug(f"No configuration found for plugin {self.plugin_id}, user {user_id}")
                return {}

            finally:
                db.close()

        except Exception as e:
            self.logger.error(f"Failed to get configuration: {e}")
            return {}

    async def save_config(self, config: Dict[str, Any], user_id: str,
                          name: str = "Default Configuration",
                          description: Optional[str] = None) -> bool:
        """Save plugin configuration for user.

        Updates the row matching (plugin, user, name) when one exists,
        otherwise inserts a new row. A new row is marked default when it is
        named "Default Configuration" or when it is the first configuration
        for this plugin/user pair.

        Returns:
            ``True`` after a successful commit; ``False`` on any error (the
            transaction is rolled back and the error is logged).
        """
        try:
            from app.db.database import SessionLocal
            db = SessionLocal()

            try:
                # Check if configuration already exists
                existing_config = db.query(PluginConfiguration).filter(
                    PluginConfiguration.plugin_id == self.plugin_id,
                    PluginConfiguration.user_id == user_id,
                    PluginConfiguration.name == name
                ).first()

                if existing_config:
                    # Update existing configuration
                    existing_config.config_data = config
                    existing_config.description = description
                    existing_config.is_active = True

                    self.logger.info(f"Updated configuration for plugin {self.plugin_id}, user {user_id}")
                else:
                    # Create new configuration
                    new_config = PluginConfiguration(
                        plugin_id=self.plugin_id,
                        user_id=user_id,
                        name=name,
                        description=description,
                        config_data=config,
                        is_active=True,
                        is_default=(name == "Default Configuration"),
                        created_by_user_id=user_id
                    )

                    # If this is the first configuration for this user/plugin, make it default
                    existing_count = db.query(PluginConfiguration).filter(
                        PluginConfiguration.plugin_id == self.plugin_id,
                        PluginConfiguration.user_id == user_id
                    ).count()

                    if existing_count == 0:
                        new_config.is_default = True

                    db.add(new_config)
                    self.logger.info(f"Created new configuration for plugin {self.plugin_id}, user {user_id}")

                db.commit()
                return True

            except Exception as e:
                db.rollback()
                self.logger.error(f"Database error saving configuration: {e}")
                return False
            finally:
                db.close()

        except Exception as e:
            self.logger.error(f"Failed to save configuration: {e}")
            return False

    async def validate_config(self, config: Dict[str, Any],
                              schema: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Validate configuration against JSON schema.

        Returns:
            ``(True, [])`` when ``config`` satisfies ``schema``; otherwise
            ``(False, [message])``. This method never raises: a missing
            ``jsonschema`` package or an invalid schema is reported through
            the same tuple.
        """
        # Import outside the validating try-block: if the import failed
        # inside it, evaluating `except jsonschema.ValidationError` would
        # itself raise UnboundLocalError instead of returning an error tuple.
        try:
            import jsonschema
        except ImportError as e:
            return False, [f"Schema validation error: {str(e)}"]

        try:
            jsonschema.validate(config, schema)
        except jsonschema.ValidationError as e:
            return False, [str(e)]
        except Exception as e:
            return False, [f"Schema validation error: {str(e)}"]
        return True, []
|
|
|
|
|
|
class PluginLogger:
    """Plugin-specific logger with security filtering.

    Every message is passed through ``_filter_sensitive_data`` and prefixed
    with ``[PLUGIN:<plugin_id>]`` before being handed to the underlying
    logger.
    """

    def __init__(self, plugin_id: str):
        """Create the underlying logger and the list of sensitive key names."""
        self.plugin_id = plugin_id
        self.logger = get_logger(f"plugin.{plugin_id}")

        # Sensitive data patterns to filter
        self.sensitive_patterns = [
            r'password', r'token', r'key', r'secret', r'api_key',
            r'bearer', r'authorization', r'credential'
        ]

    def _filter_sensitive_data(self, message: str) -> str:
        """Filter sensitive data from log messages.

        For each sensitive key name, any ``key=value`` / ``key: value``
        occurrence (case-insensitive, value optionally quoted) is rewritten
        to ``<key>=***REDACTED***``. A leading auth scheme word
        (``Bearer``/``Basic``) is treated as part of the value so that
        ``Authorization: Bearer <token>`` does not leave the token behind.

        Non-string messages (e.g. an exception object) are converted with
        ``str()`` first.
        """
        import re

        # Raw string: `\s` must reach the regex engine untouched.
        value_part = r'''[=:]\s*["']?(?:(?:bearer|basic)\s+)?[^"'\s]+["']?'''

        filtered_message = message if isinstance(message, str) else str(message)
        for pattern in self.sensitive_patterns:
            filtered_message = re.sub(
                pattern + value_part,
                f'{pattern}=***REDACTED***',
                filtered_message,
                flags=re.IGNORECASE
            )
        return filtered_message

    def _emit(self, level: str, message: str, **kwargs) -> None:
        """Filter ``message`` and forward it to the logger method ``level``."""
        filtered_message = self._filter_sensitive_data(message)
        log_method = getattr(self.logger, level)
        log_method(f"[PLUGIN:{self.plugin_id}] {filtered_message}", **kwargs)

    def info(self, message: str, **kwargs):
        """Log info message with sensitive data filtering"""
        self._emit('info', message, **kwargs)

    def warning(self, message: str, **kwargs):
        """Log warning message with sensitive data filtering"""
        self._emit('warning', message, **kwargs)

    def error(self, message: str, **kwargs):
        """Log error message with sensitive data filtering"""
        self._emit('error', message, **kwargs)

    def debug(self, message: str, **kwargs):
        """Log debug message with sensitive data filtering"""
        self._emit('debug', message, **kwargs)
|
|
|
|
|
|
class BasePlugin(ABC):
    """Base class for all Enclava plugins with security and isolation.

    Subclasses must implement ``get_api_router``, ``initialize`` and
    ``cleanup``. The base class wires up the platform API client, the
    configuration manager and the filtered logger, and keeps simple
    request/error counters used by ``health_check``.
    """

    def __init__(self, manifest: PluginManifest, plugin_token: str):
        """Set up plugin services from ``manifest`` and ``plugin_token``."""
        self.manifest = manifest
        self.plugin_id = manifest.metadata.name
        self.version = manifest.metadata.version

        # Initialize plugin services
        self.api_client = PlatformAPIClient(self.plugin_id, plugin_token)
        self.config = PluginConfigManager(self.plugin_id)
        self.logger = PluginLogger(self.plugin_id)

        # Plugin state
        self.initialized = False
        self._startup_time = time.time()
        self._request_count = 0
        self._error_count = 0

        self.logger.info(f"Plugin {self.plugin_id} v{self.version} instantiated")

    @abstractmethod
    def get_api_router(self) -> APIRouter:
        """Return FastAPI router for plugin endpoints"""
        pass

    @abstractmethod
    async def initialize(self) -> bool:
        """Initialize plugin resources and connections"""
        pass

    @abstractmethod
    async def cleanup(self) -> bool:
        """Cleanup plugin resources on shutdown"""
        pass

    async def health_check(self) -> Dict[str, Any]:
        """Plugin health status.

        Returns a dict with uptime, request/error counters and a ``status``
        of ``"healthy"`` while the error rate is below 0.1, ``"warning"``
        otherwise. The rate divides by at least 1 to avoid division by zero
        before any request was tracked.
        """
        uptime = time.time() - self._startup_time
        error_rate = self._error_count / max(self._request_count, 1)

        return {
            "status": "healthy" if error_rate < 0.1 else "warning",
            "plugin": self.plugin_id,
            "version": self.version,
            "uptime_seconds": round(uptime, 2),
            "request_count": self._request_count,
            "error_count": self._error_count,
            "error_rate": round(error_rate, 3),
            "initialized": self.initialized
        }

    async def get_configuration_schema(self) -> Dict[str, Any]:
        """Return JSON schema for plugin configuration"""
        return self.manifest.spec.config_schema

    async def execute_cron_job(self, job_name: str) -> bool:
        """Execute scheduled cron job.

        Looks up ``job_name`` among the manifest's cron jobs and calls the
        method of this plugin named by the job's ``function`` field.

        Returns:
            ``bool(result)`` of the job function; ``False`` when the job or
            its function does not exist or the function raises (the error
            counter is incremented in the latter case).
        """
        self.logger.info(f"Executing cron job: {job_name}")

        # Find job in manifest
        job_spec = next(
            (job for job in self.manifest.spec.cron_jobs if job.name == job_name),
            None
        )

        if not job_spec:
            self.logger.error(f"Cron job not found: {job_name}")
            return False

        try:
            # Get the function to execute
            if not hasattr(self, job_spec.function):
                self.logger.error(f"Cron job function not found: {job_spec.function}")
                return False

            func = getattr(self, job_spec.function)
            result = func()
            # Await whatever coroutine the call produced; this also covers
            # wrapped callables that iscoroutinefunction() does not recognise.
            if asyncio.iscoroutine(result):
                result = await result

            self.logger.info(f"Cron job {job_name} completed successfully")
            return bool(result)

        except Exception as e:
            self.logger.error(f"Cron job {job_name} failed: {e}")
            self._error_count += 1
            return False

    def get_auth_context(self) -> PluginContext:
        """Dependency to get authentication context in API endpoints.

        Returns a FastAPI ``Depends(...)`` marker whose dependency builds a
        ``PluginContext`` from request headers. The permissions header is
        split on commas; blank entries are dropped so a missing or empty
        header yields ``[]`` rather than ``['']``.
        """
        async def _get_context(request: Request) -> PluginContext:
            # Extract authentication info from request
            # This would be populated by the plugin API gateway
            raw_permissions = request.headers.get('X-User-Permissions', '')
            permissions = [
                item.strip() for item in raw_permissions.split(',') if item.strip()
            ]
            return PluginContext(
                user_id=request.headers.get('X-User-ID'),
                api_key_id=request.headers.get('X-API-Key-ID'),
                user_permissions=permissions,
                ip_address=request.headers.get('X-Real-IP'),
                user_agent=request.headers.get('User-Agent'),
                request_id=request.headers.get('X-Request-ID')
            )

        return Depends(_get_context)

    def _track_request(self, success: bool = True):
        """Track request metrics"""
        self._request_count += 1
        if not success:
            self._error_count += 1
|
|
|
|
|
|
class PluginSecurityManager:
    """Manages plugin security and isolation"""

    BLOCKED_IMPORTS = {
        # Core platform modules
        'app.db', 'app.models', 'app.core', 'app.services',
        'sqlalchemy', 'alembic',

        # Security sensitive
        'subprocess', 'eval', 'exec', 'compile', '__import__',
        'os.system', 'os.popen', 'os.spawn',

        # System access
        'socket', 'multiprocessing', 'threading'
    }

    ALLOWED_IMPORTS = {
        # Standard library
        'asyncio', 'aiohttp', 'json', 'datetime', 'typing', 'pydantic',
        'logging', 'time', 'uuid', 'hashlib', 'base64', 'pathlib',
        're', 'urllib.parse', 'dataclasses', 'enum',

        # Approved third-party
        'httpx', 'requests', 'pandas', 'numpy', 'yaml',

        # Plugin framework
        'app.services.base_plugin', 'app.schemas.plugin_manifest'
    }

    @staticmethod
    def _is_module_or_submodule(import_name: str, module: str) -> bool:
        """Return True when ``import_name`` is ``module`` or lies inside it."""
        return import_name == module or import_name.startswith(module + '.')

    @classmethod
    def validate_plugin_import(cls, import_name: str) -> bool:
        """Validate if plugin can import a module.

        Resolution order:

        1. Names on the allow list (exact module or a submodule of it) are
           accepted. This runs first so the explicitly approved framework
           module ``app.services.base_plugin`` is not rejected by the
           broader ``app.services`` block entry.
        2. Names starting with any block-list entry are rejected.
        3. Anything else is accepted but logged as potentially unsafe.

        Returns:
            ``True`` when the import is permitted.

        Raises:
            SecurityError: When the import matches the block list.
        """
        # Allow explicit safe imports (module-boundary match, so 're' does
        # not implicitly approve e.g. 'redis').
        if any(cls._is_module_or_submodule(import_name, allowed)
               for allowed in cls.ALLOWED_IMPORTS):
            return True

        # Block dangerous imports (deliberately a conservative prefix match)
        if any(import_name.startswith(blocked) for blocked in cls.BLOCKED_IMPORTS):
            raise SecurityError(f"Import '{import_name}' not allowed in plugin environment")

        # Log potentially unsafe imports
        logger = get_logger("plugin.security")
        logger.warning(f"Potentially unsafe import in plugin: {import_name}")
        return True

    @classmethod
    def create_plugin_sandbox(cls, plugin_id: str) -> Dict[str, Any]:
        """Create isolated environment for plugin execution.

        Returns a fresh dict of resource limits; the values are the same for
        every plugin (``plugin_id`` is currently not consulted).
        """
        return {
            'max_memory_mb': 128,
            'max_cpu_percent': 25,
            'max_disk_mb': 100,
            'max_api_calls_per_minute': 100,
            'allowed_domains': [],  # Will be populated from manifest
            'network_timeout_seconds': 30
        }
|
|
|
|
|
|
class PluginLoader:
    """Loads and validates plugins from directories"""

    def __init__(self):
        """Create the loader with an empty registry of loaded plugins."""
        self.logger = get_logger("plugin.loader")
        self.loaded_plugins: Dict[str, BasePlugin] = {}

    async def load_plugin(self, plugin_dir: Path, plugin_token: str) -> BasePlugin:
        """Load a plugin from a directory.

        Validates ``manifest.yaml``, statically scans ``main.py``, executes
        it as module ``plugin_<name>``, instantiates the first
        ``BasePlugin`` subclass found and awaits its ``initialize()``.

        Returns:
            The initialized plugin instance (also stored in
            ``self.loaded_plugins`` under the manifest name).

        Raises:
            ValidationError: Invalid/incompatible manifest, module that
                cannot be loaded or executed, missing plugin class, or
                failed initialization (exception or an explicit ``False``).
            SecurityError: When the static scan finds a dangerous pattern.
        """
        self.logger.info(f"Loading plugin from: {plugin_dir}")

        # Load and validate manifest
        manifest_path = plugin_dir / "manifest.yaml"
        # NOTE(review): validate_manifest_file is not among the imports
        # visible in this module - confirm it is in scope at runtime.
        validation_result = validate_manifest_file(manifest_path)

        if not validation_result["valid"]:
            raise ValidationError(f"Invalid plugin manifest: {validation_result['errors']}")

        manifest = validation_result["manifest"]

        # Check compatibility
        compatibility = validation_result["compatibility"]
        if not compatibility["compatible"]:
            raise ValidationError(f"Plugin incompatible: {compatibility['errors']}")

        # Load plugin module
        main_py_path = plugin_dir / "main.py"
        spec = importlib.util.spec_from_file_location(
            f"plugin_{manifest.metadata.name}",
            main_py_path
        )

        if not spec or not spec.loader:
            raise ValidationError(f"Cannot load plugin module: {main_py_path}")

        # Security check before loading
        self._validate_plugin_security(main_py_path)

        # Load module
        plugin_module = importlib.util.module_from_spec(spec)

        # Add to sys.modules to allow imports
        sys.modules[spec.name] = plugin_module

        try:
            try:
                spec.loader.exec_module(plugin_module)
            except Exception as e:
                raise ValidationError(f"Failed to execute plugin module: {e}") from e

            plugin_class = self._find_plugin_class(plugin_module)
            if not plugin_class:
                raise ValidationError("Plugin must contain a class inheriting from BasePlugin")

            # Instantiate plugin
            plugin_instance = plugin_class(manifest, plugin_token)

            # Initialize plugin
            try:
                init_result = await plugin_instance.initialize()
            except Exception as e:
                raise ValidationError(f"Plugin initialization failed: {e}") from e

            # initialize() is declared to return bool; an explicit False is
            # a reported failure and must not be marked as initialized.
            if init_result is False:
                raise ValidationError("Plugin initialization failed: initialize() returned False")
            plugin_instance.initialized = True
        except Exception:
            # Do not leave a half-loaded module registered after a failure.
            sys.modules.pop(spec.name, None)
            raise

        self.loaded_plugins[manifest.metadata.name] = plugin_instance
        self.logger.info(f"Plugin {manifest.metadata.name} loaded successfully")

        return plugin_instance

    @staticmethod
    def _find_plugin_class(plugin_module):
        """Return the first BasePlugin subclass in ``plugin_module``, or None."""
        for attr_name in dir(plugin_module):
            attr = getattr(plugin_module, attr_name)
            if (isinstance(attr, type) and
                    issubclass(attr, BasePlugin) and
                    attr is not BasePlugin):
                return attr
        return None

    def _validate_plugin_security(self, main_py_path: Path):
        """Validate plugin code for security issues.

        Performs a plain substring scan of the file's text; the first
        matching pattern (in list order) triggers the error.

        Raises:
            SecurityError: When a dangerous pattern occurs in the source.
        """
        with open(main_py_path, 'r', encoding='utf-8') as f:
            code_content = f.read()

        # Check for dangerous patterns
        dangerous_patterns = [
            'eval(', 'exec(', 'compile(',
            'subprocess.', 'os.system', 'os.popen',
            '__import__', 'importlib.import_module',
            'from app.db', 'from app.models',
            'sqlalchemy', 'SessionLocal'
        ]

        for pattern in dangerous_patterns:
            if pattern in code_content:
                raise SecurityError(f"Dangerous pattern detected in plugin code: {pattern}")

    async def unload_plugin(self, plugin_id: str) -> bool:
        """Unload a plugin and cleanup resources.

        Returns:
            ``True`` when the plugin was loaded and its ``cleanup()``
            completed; ``False`` when it was not loaded or cleanup raised
            (in which case the plugin stays registered).
        """
        if plugin_id not in self.loaded_plugins:
            return False

        plugin = self.loaded_plugins[plugin_id]

        try:
            await plugin.cleanup()
            del self.loaded_plugins[plugin_id]
            # Drop the module registered by load_plugin under the same name.
            sys.modules.pop(f"plugin_{plugin_id}", None)
            self.logger.info(f"Plugin {plugin_id} unloaded successfully")
            return True
        except Exception as e:
            self.logger.error(f"Error unloading plugin {plugin_id}: {e}")
            return False

    def get_plugin(self, plugin_id: str) -> Optional[BasePlugin]:
        """Get loaded plugin by ID"""
        return self.loaded_plugins.get(plugin_id)

    def list_loaded_plugins(self) -> List[str]:
        """List all loaded plugin IDs"""
        return list(self.loaded_plugins)