mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
363 lines
15 KiB
Python
363 lines
15 KiB
Python
"""
Trusted Execution Environment (TEE) Service

Handles Privatemode.ai TEE integration for confidential computing
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Dict, Any, Optional, List
|
|
from datetime import datetime, timedelta
|
|
from enum import Enum
|
|
|
|
import aiohttp
|
|
from fastapi import HTTPException, status
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
|
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
|
import base64
|
|
|
|
from app.core.config import settings
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TEEStatus(str, Enum):
    """Health states reported for the TEE environment.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    HEALTHY = "healthy"    # health endpoint answered with HTTP 200
    DEGRADED = "degraded"  # health endpoint answered with a non-200 status
    OFFLINE = "offline"    # health endpoint could not be reached at all
    UNKNOWN = "unknown"    # no health information available
|
|
|
|
|
|
class AttestationStatus(str, Enum):
    """Outcome states of an attestation request or verification.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    VERIFIED = "verified"  # attestation accepted
    FAILED = "failed"      # attestation request or verification failed
    PENDING = "pending"    # attestation not yet completed
    EXPIRED = "expired"    # attestation no longer valid
|
|
|
|
|
|
class TEEService:
    """Service for managing Privatemode.ai TEE integration.

    Wraps the HTTP API exposed at ``privatemode_base_url`` (health,
    attestation, capabilities, sessions, metrics, models) behind async
    methods that share one lazily created ``aiohttp.ClientSession``.

    Most methods never raise: transport or protocol failures are logged and
    reported inside the returned dict (or as an empty list). The exception
    is ``create_secure_session``, which raises ``HTTPException``.
    """

    def __init__(self):
        """Set up endpoint configuration and an empty attestation cache."""
        self.privatemode_base_url = "http://privatemode-proxy:8080"
        self.privatemode_api_key = settings.PRIVATEMODE_API_KEY
        # Created on first use by _get_session().
        self.session: Optional[aiohttp.ClientSession] = None
        self.timeout = aiohttp.ClientTimeout(total=300)  # 5 minutes timeout
        # Maps "attestation_<nonce>" -> result dict from get_attestation().
        self.attestation_cache: Dict[str, Dict[str, Any]] = {}
        self.attestation_ttl = timedelta(hours=1)  # Cache TTL

    @staticmethod
    def _timestamp() -> str:
        """Return the current naive-UTC time as an ISO 8601 string."""
        return datetime.utcnow().isoformat()

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating it if missing or closed.

        A new session carries the configured timeout, a JSON content type
        and a Bearer ``Authorization`` header built from the API key.
        """
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(
                timeout=self.timeout,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.privatemode_api_key}",
                },
            )
        return self.session

    async def close(self):
        """Close the HTTP session if one is open."""
        if self.session and not self.session.closed:
            await self.session.close()

    async def health_check(self) -> Dict[str, Any]:
        """Check TEE environment health via ``GET /health``.

        Returns:
            A dict with ``status`` and ``timestamp``. On HTTP 200 the status
            is ``healthy`` and the upstream flags plus the raw payload
            (``details``) are included. Any other HTTP status yields
            ``degraded``; any exception yields ``offline``. Both failure
            shapes carry an ``error`` string.
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self.privatemode_base_url}/health") as response:
                if response.status != 200:
                    return {
                        "status": TEEStatus.DEGRADED.value,
                        "timestamp": self._timestamp(),
                        "error": f"HTTP {response.status}",
                    }

                health_data = await response.json()
                return {
                    "status": TEEStatus.HEALTHY.value,
                    "timestamp": self._timestamp(),
                    "tee_enabled": health_data.get("tee_enabled", False),
                    "attestation_available": health_data.get("attestation_available", False),
                    "secure_memory": health_data.get("secure_memory", False),
                    "details": health_data,
                }
        except Exception as e:
            logger.error("TEE health check error: %s", e)
            return {
                "status": TEEStatus.OFFLINE.value,
                "timestamp": self._timestamp(),
                "error": str(e),
            }

    async def get_attestation(self, nonce: Optional[str] = None) -> Dict[str, Any]:
        """Fetch a TEE attestation report via ``POST /attestation``.

        Args:
            nonce: Value sent with the request. When omitted or empty, 32
                random bytes are generated and base64-encoded.

        Returns:
            On HTTP 200, a dict with the report, signature, certificate
            chain, measurements and TEE type; this result is cached under
            the nonce and reused while younger than ``attestation_ttl``.
            On any other status or on an exception, a dict with status
            ``failed``, ``verified`` False and an ``error`` string.
            This method does not raise.
        """
        try:
            if not nonce:
                nonce = base64.b64encode(os.urandom(32)).decode()

            # Check cache first
            cache_key = f"attestation_{nonce}"
            cached_result = self.attestation_cache.get(cache_key)
            if cached_result is not None:
                cached_at = datetime.fromisoformat(cached_result["timestamp"])
                if cached_at + self.attestation_ttl > datetime.utcnow():
                    return cached_result

            session = await self._get_session()
            payload = {"nonce": nonce}

            async with session.post(
                f"{self.privatemode_base_url}/attestation",
                json=payload,
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(
                        "TEE attestation failed: %s - %s", response.status, error_text
                    )
                    return {
                        "status": AttestationStatus.FAILED.value,
                        "timestamp": self._timestamp(),
                        "nonce": nonce,
                        "error": error_text,
                        "verified": False,
                    }

                attestation_data = await response.json()

                # NOTE(review): the report is marked verified purely because
                # the endpoint answered 200; no signature or certificate
                # check happens here. Callers needing assurance should not
                # rely on this flag alone.
                result = {
                    "status": AttestationStatus.VERIFIED.value,
                    "timestamp": self._timestamp(),
                    "nonce": nonce,
                    "report": attestation_data.get("report"),
                    "signature": attestation_data.get("signature"),
                    "certificate_chain": attestation_data.get("certificate_chain"),
                    "measurements": attestation_data.get("measurements", {}),
                    "tee_type": attestation_data.get("tee_type", "unknown"),
                    "verified": True,
                }

                # Cache the result
                self.attestation_cache[cache_key] = result
                return result
        except Exception as e:
            logger.error("TEE attestation error: %s", e)
            return {
                "status": AttestationStatus.FAILED.value,
                "timestamp": self._timestamp(),
                "nonce": nonce,
                "error": str(e),
                "verified": False,
            }

    async def verify_attestation(self, attestation_data: Dict[str, Any]) -> Dict[str, Any]:
        """Verify a TEE attestation report (placeholder implementation).

        Args:
            attestation_data: Dict expected to carry ``report``,
                ``signature`` and ``certificate_chain``.

        Returns:
            A failure dict if any of the three components is missing or
            falsy, or if reading them raises. Otherwise a dict reporting
            every check as valid.
        """
        try:
            # Extract components
            report = attestation_data.get("report")
            signature = attestation_data.get("signature")
            cert_chain = attestation_data.get("certificate_chain")

            if not all([report, signature, cert_chain]):
                return {
                    "verified": False,
                    "status": AttestationStatus.FAILED.value,
                    "error": "Missing required attestation components",
                }

            # SECURITY NOTE: this is a placeholder for actual attestation
            # verification - nothing is cryptographically checked and every
            # flag below is hard-coded. In production, you would:
            # 1. Validate the certificate chain
            # 2. Verify the signature using the public key
            # 3. Check measurements against known good values
            # 4. Validate the nonce
            return {
                "verified": True,
                "status": AttestationStatus.VERIFIED.value,
                "timestamp": self._timestamp(),
                "certificate_valid": True,
                "signature_valid": True,
                "measurements_valid": True,
                "nonce_valid": True,
            }
        except Exception as e:
            logger.error("Attestation verification error: %s", e)
            return {
                "verified": False,
                "status": AttestationStatus.FAILED.value,
                "error": str(e),
            }

    async def get_tee_capabilities(self) -> Dict[str, Any]:
        """Get TEE environment capabilities via ``GET /capabilities``.

        Returns:
            On HTTP 200, a dict of capability fields (with defaults for
            missing keys) plus the raw payload under ``details``. Otherwise
            a dict with ``timestamp`` and ``error``. This method does not
            raise.
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self.privatemode_base_url}/capabilities") as response:
                if response.status != 200:
                    return {
                        "timestamp": self._timestamp(),
                        "error": f"Failed to get capabilities: HTTP {response.status}",
                    }

                capabilities = await response.json()
                return {
                    "timestamp": self._timestamp(),
                    "tee_type": capabilities.get("tee_type", "unknown"),
                    "secure_memory_size": capabilities.get("secure_memory_size", 0),
                    "encryption_algorithms": capabilities.get("encryption_algorithms", []),
                    "attestation_types": capabilities.get("attestation_types", []),
                    "key_management": capabilities.get("key_management", False),
                    "secure_storage": capabilities.get("secure_storage", False),
                    "network_isolation": capabilities.get("network_isolation", False),
                    "confidential_computing": capabilities.get("confidential_computing", False),
                    "details": capabilities,
                }
        except Exception as e:
            logger.error("TEE capabilities error: %s", e)
            return {
                "timestamp": self._timestamp(),
                "error": str(e),
            }

    async def create_secure_session(self, user_id: str, api_key_id: int) -> Dict[str, Any]:
        """Create a secure TEE session via ``POST /session``.

        Args:
            user_id: Identifier forwarded in the request payload.
            api_key_id: API key identifier forwarded in the request payload.

        Returns:
            On HTTP 201, a dict with ``session_id``, ``status`` ("active"),
            ``timestamp``, ``capabilities``, ``expires_at`` and
            ``attestation_token``.

        Raises:
            HTTPException: With the upstream status when it is an error
                status (>= 400); with 502 when the upstream answers with an
                unexpected non-error status; with 503 on a client/transport
                error or a request timeout.
        """
        try:
            session = await self._get_session()
            payload = {
                "user_id": user_id,
                "api_key_id": api_key_id,
                "timestamp": self._timestamp(),
                "requested_capabilities": [
                    "confidential_inference",
                    "secure_memory",
                    "attestation",
                ],
            }

            async with session.post(
                f"{self.privatemode_base_url}/session",
                json=payload,
            ) as response:
                if response.status == 201:
                    session_data = await response.json()
                    return {
                        "session_id": session_data.get("session_id"),
                        "status": "active",
                        "timestamp": self._timestamp(),
                        "capabilities": session_data.get("capabilities", []),
                        "expires_at": session_data.get("expires_at"),
                        "attestation_token": session_data.get("attestation_token"),
                    }

                error_text = await response.text()
                logger.error(
                    "TEE session creation failed: %s - %s", response.status, error_text
                )
                # Forwarding a 2xx/3xx upstream status would make this error
                # look like a success or redirect to our own client, so only
                # genuine error statuses are passed through.
                if response.status >= 400:
                    error_status = response.status
                else:
                    error_status = status.HTTP_502_BAD_GATEWAY
                raise HTTPException(
                    status_code=error_status,
                    detail=f"Failed to create TEE session: {error_text}",
                )
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # A ClientTimeout expiry surfaces as asyncio.TimeoutError, which
            # is not an aiohttp.ClientError, so it is caught explicitly.
            logger.error("TEE session creation error: %s", e)
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="TEE service unavailable",
            ) from e

    async def get_privacy_metrics(self) -> Dict[str, Any]:
        """Get privacy and security metrics via ``GET /metrics``.

        Returns:
            On HTTP 200, a dict of metric fields with defaults for missing
            keys. Otherwise a dict with ``timestamp`` and ``error``. This
            method does not raise.
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self.privatemode_base_url}/metrics") as response:
                if response.status != 200:
                    return {
                        "timestamp": self._timestamp(),
                        "error": f"Failed to get metrics: HTTP {response.status}",
                    }

                metrics = await response.json()
                return {
                    "timestamp": self._timestamp(),
                    "requests_processed": metrics.get("requests_processed", 0),
                    "data_encrypted": metrics.get("data_encrypted", 0),
                    "attestations_verified": metrics.get("attestations_verified", 0),
                    "secure_sessions": metrics.get("secure_sessions", 0),
                    "uptime": metrics.get("uptime", 0),
                    "memory_usage": metrics.get("memory_usage", {}),
                    "performance": metrics.get("performance", {}),
                    "privacy_score": metrics.get("privacy_score", 0),
                }
        except Exception as e:
            logger.error("TEE metrics error: %s", e)
            return {
                "timestamp": self._timestamp(),
                "error": str(e),
            }

    async def list_tee_models(self) -> List[Dict[str, Any]]:
        """List available TEE models via ``GET /models``.

        Returns:
            One dict per entry of the upstream ``models`` list, with the
            provider and TEE flags filled in as constants. An empty list on
            a non-200 status or on any exception.
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self.privatemode_base_url}/models") as response:
                if response.status != 200:
                    logger.error("Failed to get TEE models: HTTP %s", response.status)
                    return []

                models_data = await response.json()
                return [
                    {
                        "id": model.get("id"),
                        "name": model.get("name"),
                        "type": model.get("type", "chat"),
                        "provider": "privatemode",
                        "tee_enabled": True,
                        "confidential_computing": True,
                        "secure_inference": True,
                        "attestation_required": model.get("attestation_required", False),
                        "max_tokens": model.get("max_tokens", 4096),
                        "cost_per_token": model.get("cost_per_token", 0.0),
                        "availability": model.get("availability", "available"),
                    }
                    for model in models_data.get("models", [])
                ]
        except Exception as e:
            logger.error("TEE models error: %s", e)
            return []

    async def cleanup_expired_cache(self):
        """Remove attestation cache entries at least ``attestation_ttl`` old."""
        current_time = datetime.utcnow()

        # Collect keys first; deleting while iterating the dict would raise.
        expired_keys = [
            key
            for key, cached_data in self.attestation_cache.items()
            if datetime.fromisoformat(cached_data["timestamp"]) + self.attestation_ttl
            <= current_time
        ]

        for key in expired_keys:
            del self.attestation_cache[key]

        logger.info(
            "Cleaned up %s expired attestation cache entries", len(expired_keys)
        )
|
|
|
|
|
|
# Global TEE service instance, created once when this module is imported.
tee_service = TEEService()