mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 15:34:36 +01:00
677 lines
23 KiB
Python
677 lines
23 KiB
Python
"""
|
|
Base Plugin Class and Plugin Runtime Environment
|
|
Provides the foundation for all Enclava plugins with security and isolation
|
|
"""
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, List, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
from fastapi import APIRouter, Request, HTTPException, Depends
|
|
import asyncio
|
|
import aiohttp
|
|
import logging
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
import importlib.util
|
|
import sys
|
|
|
|
from app.schemas.plugin_manifest import PluginManifest, PluginManifestValidator
|
|
from app.core.logging import get_logger
|
|
from app.core.config import settings
|
|
from app.utils.exceptions import SecurityError, ValidationError
|
|
from app.models.plugin import PluginConfiguration
|
|
from app.models.user import User
|
|
from app.db.database import get_db
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
@dataclass
class PluginContext:
    """Plugin execution context with user and authentication info.

    Every field defaults to ``None`` so a context can be created from
    whatever subset of values is available.
    """

    # Caller's user identifier.
    user_id: Optional[str] = None
    # Identifier of the API key used for the request.
    api_key_id: Optional[str] = None
    # Permission strings for the caller; None when not supplied.
    user_permissions: Optional[List[str]] = None
    # Caller's IP address.
    ip_address: Optional[str] = None
    # Caller's User-Agent string.
    user_agent: Optional[str] = None
    # Identifier of the request being processed.
    request_id: Optional[str] = None
|
|
|
|
|
|
class PlatformAPIClient:
    """Secure client for plugins to access platform APIs.

    Requests go to ``settings.INTERNAL_API_URL`` (or
    ``http://localhost:58000`` when that setting is falsy) and always carry
    the plugin's bearer token together with identifying headers.
    """

    def __init__(self, plugin_id: str, plugin_token: str):
        self.plugin_id = plugin_id
        self.plugin_token = plugin_token
        self.base_url = settings.INTERNAL_API_URL or "http://localhost:58000"
        self.logger = get_logger(f"plugin.{plugin_id}.api_client")

    def _build_headers(
        self, extra: Optional[Dict[str, str]] = None
    ) -> Dict[str, str]:
        """Return caller-supplied headers overlaid with the auth headers.

        A fresh dict is always returned so the plugin token is never written
        into a dict owned by the caller.
        """
        headers = dict(extra or {})
        headers.update(
            {
                "Authorization": f"Bearer {self.plugin_token}",
                "X-Plugin-ID": self.plugin_id,
                "X-Platform-Client": "plugin",
                "Content-Type": "application/json",
            }
        )
        return headers

    async def _make_request(
        self, method: str, endpoint: str, **kwargs
    ) -> Dict[str, Any]:
        """Make authenticated request to platform API.

        Args:
            method: HTTP method name.
            endpoint: Path appended verbatim to ``self.base_url``.
            **kwargs: Forwarded to ``aiohttp.ClientSession.request``; a
                ``headers`` entry is merged with the auth headers.

        Returns:
            The decoded JSON body when the response content type is
            ``application/json``; otherwise ``{"data": <response text>}``.

        Raises:
            HTTPException: with the upstream status for responses >= 400,
                or with status 503 when aiohttp raises a ``ClientError``.
        """
        kwargs["headers"] = self._build_headers(kwargs.get("headers"))
        url = f"{self.base_url}{endpoint}"

        try:
            async with aiohttp.ClientSession() as session:
                async with session.request(method, url, **kwargs) as response:
                    if response.status >= 400:
                        error_text = await response.text()
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"Platform API error: {error_text}",
                        )

                    if response.content_type == "application/json":
                        return await response.json()
                    return {"data": await response.text()}

        except aiohttp.ClientError as e:
            self.logger.error(f"Platform API client error: {e}")
            raise HTTPException(
                status_code=503, detail=f"Platform API unavailable: {str(e)}"
            ) from e

    async def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
        """GET request to platform API."""
        return await self._make_request("GET", endpoint, **kwargs)

    async def post(
        self, endpoint: str, data: Optional[Dict[str, Any]] = None, **kwargs
    ) -> Dict[str, Any]:
        """POST request to platform API; ``data`` is sent as the JSON body."""
        # ``is not None`` so an explicit empty dict is still sent as ``{}``.
        if data is not None:
            kwargs["json"] = data
        return await self._make_request("POST", endpoint, **kwargs)

    async def put(
        self, endpoint: str, data: Optional[Dict[str, Any]] = None, **kwargs
    ) -> Dict[str, Any]:
        """PUT request to platform API; ``data`` is sent as the JSON body."""
        if data is not None:
            kwargs["json"] = data
        return await self._make_request("PUT", endpoint, **kwargs)

    async def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
        """DELETE request to platform API."""
        return await self._make_request("DELETE", endpoint, **kwargs)

    # Platform-specific API methods
    async def call_chatbot_api(
        self,
        chatbot_id: str,
        message: str,
        context: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Consume platform chatbot API (missing context is sent as ``{}``)."""
        return await self.post(
            f"/api/v1/chatbot/external/{chatbot_id}/chat",
            {"message": message, "context": context or {}},
        )

    async def call_llm_api(
        self, model: str, messages: List[Dict[str, Any]], **kwargs
    ) -> Dict[str, Any]:
        """Consume platform LLM API.

        Extra keyword arguments are merged into the JSON payload next to
        ``model`` and ``messages``.
        """
        return await self.post(
            "/api/v1/llm/chat/completions",
            {"model": model, "messages": messages, **kwargs},
        )

    async def search_rag(
        self, collection: str, query: str, top_k: int = 5
    ) -> Dict[str, Any]:
        """Consume platform RAG API."""
        return await self.post(
            f"/api/v1/rag/collections/{collection}/search",
            {"query": query, "top_k": top_k},
        )

    async def get_embeddings(self, model: str, input_text: str) -> Dict[str, Any]:
        """Generate embeddings via platform API."""
        return await self.post(
            "/api/v1/llm/embeddings", {"model": model, "input": input_text}
        )
|
|
|
|
|
|
class PluginConfigManager:
    """Loads, stores and validates per-plugin configuration records.

    Records are ``PluginConfiguration`` rows scoped by ``plugin_id`` and
    reached through a session created from ``SessionLocal``.
    """

    def __init__(self, plugin_id: str):
        self.plugin_id = plugin_id
        self.logger = get_logger(f"plugin.{plugin_id}.config")

    async def get_config(self, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Get plugin configuration for a user, or the default one.

        With ``user_id`` the first active row for that user is used;
        without it the first active row flagged ``is_default`` is used.

        Returns:
            The row's ``config_data`` (``{}`` when it is empty), or ``{}``
            when no row matches or any error occurs (errors are logged).
        """
        try:
            # The session is created directly from the factory and closed in
            # the ``finally`` below.
            from app.db.database import SessionLocal

            db = SessionLocal()

            try:
                # ``== True`` is intentional: these are SQLAlchemy column
                # expressions, not Python truth tests.
                query = db.query(PluginConfiguration).filter(
                    PluginConfiguration.plugin_id == self.plugin_id,
                    PluginConfiguration.is_active == True,  # noqa: E712
                )

                if user_id:
                    # User-specific configuration
                    query = query.filter(PluginConfiguration.user_id == user_id)
                else:
                    # Default configuration (is_default=True)
                    query = query.filter(
                        PluginConfiguration.is_default == True  # noqa: E712
                    )

                config = query.first()

                if not config:
                    self.logger.debug(
                        f"No configuration found for plugin {self.plugin_id}, user {user_id}"
                    )
                    return {}

                self.logger.debug(
                    f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}"
                )
                return config.config_data or {}

            finally:
                db.close()

        except Exception as e:
            self.logger.error(f"Failed to get configuration: {e}")
            return {}

    async def save_config(
        self,
        config: Dict[str, Any],
        user_id: str,
        name: str = "Default Configuration",
        description: Optional[str] = None,
    ) -> bool:
        """Save plugin configuration for user.

        An existing row with the same plugin, user and ``name`` is updated
        in place and re-activated; otherwise a new row is created. A new row
        is flagged default when it is named ``"Default Configuration"`` or
        when the user has no other row for this plugin.

        Returns:
            True when the change was committed, False on any error (the
            transaction is rolled back and the error is logged).
        """
        try:
            from app.db.database import SessionLocal

            db = SessionLocal()

            try:
                # Check if configuration already exists
                existing_config = (
                    db.query(PluginConfiguration)
                    .filter(
                        PluginConfiguration.plugin_id == self.plugin_id,
                        PluginConfiguration.user_id == user_id,
                        PluginConfiguration.name == name,
                    )
                    .first()
                )

                if existing_config:
                    # Update existing configuration
                    existing_config.config_data = config
                    existing_config.description = description
                    existing_config.is_active = True

                    self.logger.info(
                        f"Updated configuration for plugin {self.plugin_id}, user {user_id}"
                    )
                else:
                    # The first configuration for this user/plugin becomes
                    # the default regardless of its name.
                    existing_count = (
                        db.query(PluginConfiguration)
                        .filter(
                            PluginConfiguration.plugin_id == self.plugin_id,
                            PluginConfiguration.user_id == user_id,
                        )
                        .count()
                    )
                    is_first = existing_count == 0

                    new_config = PluginConfiguration(
                        plugin_id=self.plugin_id,
                        user_id=user_id,
                        name=name,
                        description=description,
                        config_data=config,
                        is_active=True,
                        is_default=(name == "Default Configuration") or is_first,
                        created_by_user_id=user_id,
                    )

                    db.add(new_config)
                    self.logger.info(
                        f"Created new configuration for plugin {self.plugin_id}, user {user_id}"
                    )

                db.commit()
                return True

            except Exception as e:
                db.rollback()
                self.logger.error(f"Database error saving configuration: {e}")
                return False
            finally:
                db.close()

        except Exception as e:
            self.logger.error(f"Failed to save configuration: {e}")
            return False

    async def validate_config(
        self, config: Dict[str, Any], schema: Dict[str, Any]
    ) -> Tuple[bool, List[str]]:
        """Validate configuration against JSON schema.

        Returns:
            ``(True, [])`` when ``config`` satisfies ``schema``; otherwise
            ``(False, [message])``. This method never raises: a missing
            ``jsonschema`` package or an invalid schema is reported as a
            ``"Schema validation error: ..."`` message.
        """
        # Import separately from the validation ``try``: if the import failed
        # inside it, evaluating ``except jsonschema.ValidationError`` would
        # itself raise because the name was never bound.
        try:
            import jsonschema
        except ImportError as e:
            return False, [f"Schema validation error: {str(e)}"]

        try:
            jsonschema.validate(config, schema)
        except jsonschema.ValidationError as e:
            return False, [str(e)]
        except Exception as e:
            return False, [f"Schema validation error: {str(e)}"]
        return True, []
|
|
|
|
|
|
class PluginLogger:
    """Plugin-specific logger with security filtering.

    Each message is scrubbed of ``key=value`` / ``key: value`` pairs whose
    key matches one of ``sensitive_patterns`` and is prefixed with
    ``[PLUGIN:<plugin_id>]`` before reaching the underlying logger.
    """

    def __init__(self, plugin_id: str):
        self.plugin_id = plugin_id
        self.logger = get_logger(f"plugin.{plugin_id}")

        # Sensitive data patterns to filter (regex fragments, matched
        # case-insensitively as the key of a key/value pair).
        self.sensitive_patterns = [
            r"password",
            r"token",
            r"key",
            r"secret",
            r"api_key",
            r"bearer",
            r"authorization",
            r"credential",
        ]

    def _filter_sensitive_data(self, message: str) -> str:
        """Return ``message`` with sensitive values replaced.

        For each pattern, ``<pattern>[=:] <value>`` (value optionally
        quoted) becomes ``<pattern>=***REDACTED***``. An auth scheme word
        (``Bearer`` / ``Basic``) in front of the value is consumed as well,
        so ``Authorization: Bearer <token>`` does not leak the token.
        """
        import re

        # Raw string: ``\s`` must reach the regex engine untouched.
        value_part = r"""[=:]\s*["']?(?:(?:bearer|basic)\s+)?[^"'\s]+["']?"""

        filtered_message = message
        for pattern in self.sensitive_patterns:
            filtered_message = re.sub(
                pattern + value_part,
                f"{pattern}=***REDACTED***",
                filtered_message,
                flags=re.IGNORECASE,
            )
        return filtered_message

    def _format(self, message: str) -> str:
        """Filter ``message`` and add the plugin prefix."""
        return f"[PLUGIN:{self.plugin_id}] {self._filter_sensitive_data(message)}"

    def info(self, message: str, **kwargs):
        """Log info message with sensitive data filtering."""
        self.logger.info(self._format(message), **kwargs)

    def warning(self, message: str, **kwargs):
        """Log warning message with sensitive data filtering."""
        self.logger.warning(self._format(message), **kwargs)

    def error(self, message: str, **kwargs):
        """Log error message with sensitive data filtering."""
        self.logger.error(self._format(message), **kwargs)

    def debug(self, message: str, **kwargs):
        """Log debug message with sensitive data filtering."""
        self.logger.debug(self._format(message), **kwargs)
|
|
|
|
|
|
class BasePlugin(ABC):
    """Base class for all Enclava plugins with security and isolation.

    Subclasses implement ``get_api_router``, ``initialize`` and ``cleanup``.
    The base class wires up the platform API client, the configuration
    manager and the filtered logger, and keeps simple request/error
    counters used by ``health_check``.
    """

    def __init__(self, manifest: PluginManifest, plugin_token: str):
        self.manifest = manifest
        self.plugin_id = manifest.metadata.name
        self.version = manifest.metadata.version

        # Initialize plugin services
        self.api_client = PlatformAPIClient(self.plugin_id, plugin_token)
        self.config = PluginConfigManager(self.plugin_id)
        self.logger = PluginLogger(self.plugin_id)

        # Plugin state
        self.initialized = False
        self._startup_time = time.time()
        self._request_count = 0
        self._error_count = 0

        self.logger.info(f"Plugin {self.plugin_id} v{self.version} instantiated")

    @abstractmethod
    def get_api_router(self) -> APIRouter:
        """Return FastAPI router for plugin endpoints."""
        pass

    @abstractmethod
    async def initialize(self) -> bool:
        """Initialize plugin resources and connections."""
        pass

    @abstractmethod
    async def cleanup(self) -> bool:
        """Cleanup plugin resources on shutdown."""
        pass

    async def health_check(self) -> Dict[str, Any]:
        """Plugin health status.

        ``status`` is ``"healthy"`` while the error rate (errors divided by
        requests, with the divisor floored at 1) is below 0.1, otherwise
        ``"warning"``.
        """
        uptime = time.time() - self._startup_time
        # max(..., 1) avoids dividing by zero before the first request.
        error_rate = self._error_count / max(self._request_count, 1)

        return {
            "status": "healthy" if error_rate < 0.1 else "warning",
            "plugin": self.plugin_id,
            "version": self.version,
            "uptime_seconds": round(uptime, 2),
            "request_count": self._request_count,
            "error_count": self._error_count,
            "error_rate": round(error_rate, 3),
            "initialized": self.initialized,
        }

    async def get_configuration_schema(self) -> Dict[str, Any]:
        """Return JSON schema for plugin configuration."""
        return self.manifest.spec.config_schema

    async def execute_cron_job(self, job_name: str) -> bool:
        """Execute scheduled cron job.

        Looks up ``job_name`` in ``manifest.spec.cron_jobs`` and calls the
        method of this plugin named by the job's ``function`` field.

        Returns:
            The truthiness of the job's return value, or False when the
            job or its function is missing, or when the job raises (which
            also increments the error counter).
        """
        import inspect

        self.logger.info(f"Executing cron job: {job_name}")

        # Find job in manifest
        job_spec = next(
            (job for job in self.manifest.spec.cron_jobs if job.name == job_name),
            None,
        )

        if not job_spec:
            self.logger.error(f"Cron job not found: {job_name}")
            return False

        try:
            if not hasattr(self, job_spec.function):
                self.logger.error(f"Cron job function not found: {job_spec.function}")
                return False

            func = getattr(self, job_spec.function)
            result = func()
            # Await whatever awaitable comes back, so async callables that
            # are wrapped (partials, decorators) are handled as well.
            if inspect.isawaitable(result):
                result = await result

            self.logger.info(f"Cron job {job_name} completed successfully")
            return bool(result)

        except Exception as e:
            self.logger.error(f"Cron job {job_name} failed: {e}")
            self._error_count += 1
            return False

    def get_auth_context(self) -> PluginContext:
        """Dependency to get authentication context in API endpoints.

        Returns a FastAPI ``Depends`` marker wrapping a coroutine that
        builds a ``PluginContext`` from request headers.
        """

        async def _get_context(request: Request) -> PluginContext:
            # Extract authentication info from request
            # This would be populated by the plugin API gateway
            headers = request.headers
            raw_permissions = headers.get("X-User-Permissions", "")
            # Drop empty entries: "".split(",") would otherwise yield [""].
            permissions = [
                item.strip() for item in raw_permissions.split(",") if item.strip()
            ]
            return PluginContext(
                user_id=headers.get("X-User-ID"),
                api_key_id=headers.get("X-API-Key-ID"),
                user_permissions=permissions,
                ip_address=headers.get("X-Real-IP"),
                user_agent=headers.get("User-Agent"),
                request_id=headers.get("X-Request-ID"),
            )

        return Depends(_get_context)

    def _track_request(self, success: bool = True):
        """Track request metrics."""
        self._request_count += 1
        if not success:
            self._error_count += 1
|
|
|
|
|
|
class PluginSecurityManager:
    """Import policy and resource limits applied to plugins."""

    # Names a plugin must never import. Checked as plain string prefixes.
    BLOCKED_IMPORTS = {
        # Core platform modules
        "app.db",
        "app.models",
        "app.core",
        "app.services",
        "sqlalchemy",
        "alembic",
        # Security sensitive
        "subprocess",
        "eval",
        "exec",
        "compile",
        "__import__",
        "os.system",
        "os.popen",
        "os.spawn",
        # System access
        "socket",
        "multiprocessing",
        "threading",
    }

    # Names considered safe; anything else is permitted but logged.
    ALLOWED_IMPORTS = {
        # Standard library and core async / validation packages
        "asyncio",
        "aiohttp",
        "json",
        "datetime",
        "typing",
        "pydantic",
        "logging",
        "time",
        "uuid",
        "hashlib",
        "base64",
        "pathlib",
        "re",
        "urllib.parse",
        "dataclasses",
        "enum",
        # Approved third-party
        "httpx",
        "requests",
        "pandas",
        "numpy",
        "yaml",
        # Plugin framework
        "app.services.base_plugin",
        "app.schemas.plugin_manifest",
    }

    @classmethod
    def validate_plugin_import(cls, import_name: str) -> bool:
        """Check whether a plugin may import ``import_name``.

        Returns:
            True for every name that is not blocked; names outside the
            allow-list additionally produce a warning log entry.

        Raises:
            SecurityError: if ``import_name`` starts with a blocked entry.
        """
        # NOTE(review): matching is by raw string prefix, so e.g. "re" also
        # covers any name that merely begins with "re" - confirm intended.
        # The block-list is consulted first, so it wins over the allow-list.
        if import_name.startswith(tuple(cls.BLOCKED_IMPORTS)):
            raise SecurityError(
                f"Import '{import_name}' not allowed in plugin environment"
            )

        if not import_name.startswith(tuple(cls.ALLOWED_IMPORTS)):
            # Unknown imports are tolerated, but leave a trace in the logs.
            get_logger("plugin.security").warning(
                f"Potentially unsafe import in plugin: {import_name}"
            )

        return True

    @classmethod
    def create_plugin_sandbox(cls, plugin_id: str) -> Dict[str, Any]:
        """Return the resource-limit settings for a plugin.

        ``plugin_id`` is accepted but not used; every plugin receives the
        same values.
        """
        return dict(
            max_memory_mb=128,
            max_cpu_percent=25,
            max_disk_mb=100,
            max_api_calls_per_minute=100,
            allowed_domains=[],  # Will be populated from manifest
            network_timeout_seconds=30,
        )
|
|
|
|
|
|
class PluginLoader:
    """Loads and validates plugins from directories.

    Successfully loaded plugin instances are kept in ``loaded_plugins``,
    keyed by the plugin name from the manifest metadata.
    """

    # Substrings whose presence in a plugin's main.py rejects the plugin.
    _DANGEROUS_PATTERNS = (
        "eval(",
        "exec(",
        "compile(",
        "subprocess.",
        "os.system",
        "os.popen",
        "__import__",
        "importlib.import_module",
        "from app.db",
        "from app.models",
        "sqlalchemy",
        "SessionLocal",
    )

    def __init__(self):
        self.logger = get_logger("plugin.loader")
        self.loaded_plugins: Dict[str, BasePlugin] = {}

    async def load_plugin(self, plugin_dir: Path, plugin_token: str) -> BasePlugin:
        """Load a plugin from a directory.

        Steps: validate ``manifest.yaml``, check compatibility, scan
        ``main.py`` for dangerous patterns, execute the module, locate the
        ``BasePlugin`` subclass, instantiate it and run ``initialize()``.

        Args:
            plugin_dir: Directory containing ``manifest.yaml`` and ``main.py``.
            plugin_token: Token handed to the plugin's constructor.

        Returns:
            The initialized plugin instance (also stored in
            ``loaded_plugins``).

        Raises:
            ValidationError: invalid or incompatible manifest, module that
                cannot be loaded or executed, missing plugin class, or
                failed initialization.
            SecurityError: a dangerous pattern was found in ``main.py``.
        """
        # NOTE(review): ``validate_manifest_file`` is used here but is not
        # among this file's module-level imports. It is presumably defined
        # next to PluginManifestValidator - confirm the module path.
        from app.schemas.plugin_manifest import validate_manifest_file

        self.logger.info(f"Loading plugin from: {plugin_dir}")

        # Load and validate manifest
        manifest_path = plugin_dir / "manifest.yaml"
        validation_result = validate_manifest_file(manifest_path)

        if not validation_result["valid"]:
            raise ValidationError(
                f"Invalid plugin manifest: {validation_result['errors']}"
            )

        manifest = validation_result["manifest"]

        # Check compatibility
        compatibility = validation_result["compatibility"]
        if not compatibility["compatible"]:
            raise ValidationError(f"Plugin incompatible: {compatibility['errors']}")

        # Load plugin module
        main_py_path = plugin_dir / "main.py"
        spec = importlib.util.spec_from_file_location(
            f"plugin_{manifest.metadata.name}", main_py_path
        )

        if not spec or not spec.loader:
            raise ValidationError(f"Cannot load plugin module: {main_py_path}")

        # Security check before any plugin code runs
        self._validate_plugin_security(main_py_path)

        plugin_module = importlib.util.module_from_spec(spec)

        # Add to sys.modules to allow imports
        sys.modules[spec.name] = plugin_module

        try:
            spec.loader.exec_module(plugin_module)
        except Exception as e:
            # Do not leave a half-initialized module importable.
            sys.modules.pop(spec.name, None)
            raise ValidationError(f"Failed to execute plugin module: {e}") from e

        plugin_class = self._find_plugin_class(plugin_module)
        if not plugin_class:
            raise ValidationError(
                "Plugin must contain a class inheriting from BasePlugin"
            )

        # Instantiate plugin
        plugin_instance = plugin_class(manifest, plugin_token)

        # Initialize plugin
        try:
            init_result = await plugin_instance.initialize()
        except Exception as e:
            raise ValidationError(f"Plugin initialization failed: {e}") from e

        # ``initialize`` is declared to return bool; an explicit False is a
        # reported failure. None (no return statement) is still accepted.
        if init_result is False:
            raise ValidationError(
                "Plugin initialization failed: initialize() returned False"
            )
        plugin_instance.initialized = True

        self.loaded_plugins[manifest.metadata.name] = plugin_instance
        self.logger.info(f"Plugin {manifest.metadata.name} loaded successfully")

        return plugin_instance

    @staticmethod
    def _find_plugin_class(plugin_module) -> Optional[type]:
        """Return the BasePlugin subclass exposed by ``plugin_module``.

        Classes defined in the module itself are preferred over subclasses
        it merely imported; None is returned when there is no candidate.
        """
        candidates = [
            attr
            for attr in (getattr(plugin_module, name) for name in dir(plugin_module))
            if isinstance(attr, type)
            and issubclass(attr, BasePlugin)
            and attr is not BasePlugin
        ]
        for candidate in candidates:
            if candidate.__module__ == plugin_module.__name__:
                return candidate
        return candidates[0] if candidates else None

    def _validate_plugin_security(self, main_py_path: Path):
        """Validate plugin code for security issues.

        This is a plain substring scan of ``main_py_path`` only; it does not
        parse the code or look at other files.

        Raises:
            SecurityError: on the first dangerous pattern found.
        """
        with open(main_py_path, "r", encoding="utf-8") as f:
            code_content = f.read()

        for pattern in self._DANGEROUS_PATTERNS:
            if pattern in code_content:
                raise SecurityError(
                    f"Dangerous pattern detected in plugin code: {pattern}"
                )

    async def unload_plugin(self, plugin_id: str) -> bool:
        """Unload a plugin and cleanup resources.

        Returns:
            True when the plugin was cleaned up and removed; False when it
            is not loaded or its ``cleanup()`` raised (in which case it
            stays registered and the error is logged).
        """
        plugin = self.loaded_plugins.get(plugin_id)
        if plugin_id not in self.loaded_plugins:
            return False

        try:
            await plugin.cleanup()
            del self.loaded_plugins[plugin_id]
            self.logger.info(f"Plugin {plugin_id} unloaded successfully")
            return True
        except Exception as e:
            self.logger.error(f"Error unloading plugin {plugin_id}: {e}")
            return False

    def get_plugin(self, plugin_id: str) -> Optional[BasePlugin]:
        """Get loaded plugin by ID (None when not loaded)."""
        return self.loaded_plugins.get(plugin_id)

    def list_loaded_plugins(self) -> List[str]:
        """List all loaded plugin IDs."""
        return list(self.loaded_plugins)
|