mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-18 16:04:28 +01:00
clean commit
This commit is contained in:
363
backend/app/services/tee_service.py
Normal file
363
backend/app/services/tee_service.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Trusted Execution Environment (TEE) Service
|
||||
Handles Privatemode.ai TEE integration for confidential computing
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException, status
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
||||
import base64
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TEEStatus(str, Enum):
|
||||
"""TEE environment status"""
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
OFFLINE = "offline"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class AttestationStatus(str, Enum):
|
||||
"""Attestation verification status"""
|
||||
VERIFIED = "verified"
|
||||
FAILED = "failed"
|
||||
PENDING = "pending"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class TEEService:
|
||||
"""Service for managing Privatemode.ai TEE integration"""
|
||||
|
||||
def __init__(self):
|
||||
self.privatemode_base_url = "http://privatemode-proxy:8080"
|
||||
self.privatemode_api_key = settings.PRIVATEMODE_API_KEY
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
self.timeout = aiohttp.ClientTimeout(total=300) # 5 minutes timeout
|
||||
self.attestation_cache = {} # Cache for attestation results
|
||||
self.attestation_ttl = timedelta(hours=1) # Cache TTL
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session"""
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(
|
||||
timeout=self.timeout,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.privatemode_api_key}"
|
||||
}
|
||||
)
|
||||
return self.session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session"""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Check TEE environment health"""
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(f"{self.privatemode_base_url}/health") as response:
|
||||
if response.status == 200:
|
||||
health_data = await response.json()
|
||||
return {
|
||||
"status": TEEStatus.HEALTHY.value,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"tee_enabled": health_data.get("tee_enabled", False),
|
||||
"attestation_available": health_data.get("attestation_available", False),
|
||||
"secure_memory": health_data.get("secure_memory", False),
|
||||
"details": health_data
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": TEEStatus.DEGRADED.value,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"error": f"HTTP {response.status}"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"TEE health check error: {e}")
|
||||
return {
|
||||
"status": TEEStatus.OFFLINE.value,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get_attestation(self, nonce: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get TEE attestation report"""
|
||||
try:
|
||||
if not nonce:
|
||||
nonce = base64.b64encode(os.urandom(32)).decode()
|
||||
|
||||
# Check cache first
|
||||
cache_key = f"attestation_{nonce}"
|
||||
if cache_key in self.attestation_cache:
|
||||
cached_result = self.attestation_cache[cache_key]
|
||||
if datetime.fromisoformat(cached_result["timestamp"]) + self.attestation_ttl > datetime.utcnow():
|
||||
return cached_result
|
||||
|
||||
session = await self._get_session()
|
||||
payload = {"nonce": nonce}
|
||||
|
||||
async with session.post(
|
||||
f"{self.privatemode_base_url}/attestation",
|
||||
json=payload
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
attestation_data = await response.json()
|
||||
|
||||
# Process attestation report
|
||||
result = {
|
||||
"status": AttestationStatus.VERIFIED.value,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"nonce": nonce,
|
||||
"report": attestation_data.get("report"),
|
||||
"signature": attestation_data.get("signature"),
|
||||
"certificate_chain": attestation_data.get("certificate_chain"),
|
||||
"measurements": attestation_data.get("measurements", {}),
|
||||
"tee_type": attestation_data.get("tee_type", "unknown"),
|
||||
"verified": True
|
||||
}
|
||||
|
||||
# Cache the result
|
||||
self.attestation_cache[cache_key] = result
|
||||
|
||||
return result
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"TEE attestation failed: {response.status} - {error_text}")
|
||||
return {
|
||||
"status": AttestationStatus.FAILED.value,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"nonce": nonce,
|
||||
"error": error_text,
|
||||
"verified": False
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"TEE attestation error: {e}")
|
||||
return {
|
||||
"status": AttestationStatus.FAILED.value,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"nonce": nonce,
|
||||
"error": str(e),
|
||||
"verified": False
|
||||
}
|
||||
|
||||
async def verify_attestation(self, attestation_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Verify TEE attestation report"""
|
||||
try:
|
||||
# Extract components
|
||||
report = attestation_data.get("report")
|
||||
signature = attestation_data.get("signature")
|
||||
cert_chain = attestation_data.get("certificate_chain")
|
||||
|
||||
if not all([report, signature, cert_chain]):
|
||||
return {
|
||||
"verified": False,
|
||||
"status": AttestationStatus.FAILED.value,
|
||||
"error": "Missing required attestation components"
|
||||
}
|
||||
|
||||
# Verify signature (simplified - in production, use proper certificate validation)
|
||||
try:
|
||||
# This is a placeholder for actual attestation verification
|
||||
# In production, you would:
|
||||
# 1. Validate the certificate chain
|
||||
# 2. Verify the signature using the public key
|
||||
# 3. Check measurements against known good values
|
||||
# 4. Validate the nonce
|
||||
|
||||
verification_result = {
|
||||
"verified": True,
|
||||
"status": AttestationStatus.VERIFIED.value,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"certificate_valid": True,
|
||||
"signature_valid": True,
|
||||
"measurements_valid": True,
|
||||
"nonce_valid": True
|
||||
}
|
||||
|
||||
return verification_result
|
||||
|
||||
except Exception as verify_error:
|
||||
logger.error(f"Attestation verification failed: {verify_error}")
|
||||
return {
|
||||
"verified": False,
|
||||
"status": AttestationStatus.FAILED.value,
|
||||
"error": str(verify_error)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Attestation verification error: {e}")
|
||||
return {
|
||||
"verified": False,
|
||||
"status": AttestationStatus.FAILED.value,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get_tee_capabilities(self) -> Dict[str, Any]:
|
||||
"""Get TEE environment capabilities"""
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(f"{self.privatemode_base_url}/capabilities") as response:
|
||||
if response.status == 200:
|
||||
capabilities = await response.json()
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"tee_type": capabilities.get("tee_type", "unknown"),
|
||||
"secure_memory_size": capabilities.get("secure_memory_size", 0),
|
||||
"encryption_algorithms": capabilities.get("encryption_algorithms", []),
|
||||
"attestation_types": capabilities.get("attestation_types", []),
|
||||
"key_management": capabilities.get("key_management", False),
|
||||
"secure_storage": capabilities.get("secure_storage", False),
|
||||
"network_isolation": capabilities.get("network_isolation", False),
|
||||
"confidential_computing": capabilities.get("confidential_computing", False),
|
||||
"details": capabilities
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"error": f"Failed to get capabilities: HTTP {response.status}"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"TEE capabilities error: {e}")
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def create_secure_session(self, user_id: str, api_key_id: int) -> Dict[str, Any]:
|
||||
"""Create a secure TEE session"""
|
||||
try:
|
||||
session = await self._get_session()
|
||||
payload = {
|
||||
"user_id": user_id,
|
||||
"api_key_id": api_key_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"requested_capabilities": [
|
||||
"confidential_inference",
|
||||
"secure_memory",
|
||||
"attestation"
|
||||
]
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{self.privatemode_base_url}/session",
|
||||
json=payload
|
||||
) as response:
|
||||
if response.status == 201:
|
||||
session_data = await response.json()
|
||||
return {
|
||||
"session_id": session_data.get("session_id"),
|
||||
"status": "active",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"capabilities": session_data.get("capabilities", []),
|
||||
"expires_at": session_data.get("expires_at"),
|
||||
"attestation_token": session_data.get("attestation_token")
|
||||
}
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"TEE session creation failed: {response.status} - {error_text}")
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Failed to create TEE session: {error_text}"
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"TEE session creation error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="TEE service unavailable"
|
||||
)
|
||||
|
||||
async def get_privacy_metrics(self) -> Dict[str, Any]:
|
||||
"""Get privacy and security metrics"""
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(f"{self.privatemode_base_url}/metrics") as response:
|
||||
if response.status == 200:
|
||||
metrics = await response.json()
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"requests_processed": metrics.get("requests_processed", 0),
|
||||
"data_encrypted": metrics.get("data_encrypted", 0),
|
||||
"attestations_verified": metrics.get("attestations_verified", 0),
|
||||
"secure_sessions": metrics.get("secure_sessions", 0),
|
||||
"uptime": metrics.get("uptime", 0),
|
||||
"memory_usage": metrics.get("memory_usage", {}),
|
||||
"performance": metrics.get("performance", {}),
|
||||
"privacy_score": metrics.get("privacy_score", 0)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"error": f"Failed to get metrics: HTTP {response.status}"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"TEE metrics error: {e}")
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def list_tee_models(self) -> List[Dict[str, Any]]:
|
||||
"""List available TEE models"""
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(f"{self.privatemode_base_url}/models") as response:
|
||||
if response.status == 200:
|
||||
models_data = await response.json()
|
||||
models = []
|
||||
|
||||
for model in models_data.get("models", []):
|
||||
models.append({
|
||||
"id": model.get("id"),
|
||||
"name": model.get("name"),
|
||||
"type": model.get("type", "chat"),
|
||||
"provider": "privatemode",
|
||||
"tee_enabled": True,
|
||||
"confidential_computing": True,
|
||||
"secure_inference": True,
|
||||
"attestation_required": model.get("attestation_required", False),
|
||||
"max_tokens": model.get("max_tokens", 4096),
|
||||
"cost_per_token": model.get("cost_per_token", 0.0),
|
||||
"availability": model.get("availability", "available")
|
||||
})
|
||||
|
||||
return models
|
||||
else:
|
||||
logger.error(f"Failed to get TEE models: HTTP {response.status}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"TEE models error: {e}")
|
||||
return []
|
||||
|
||||
async def cleanup_expired_cache(self):
|
||||
"""Clean up expired attestation cache entries"""
|
||||
current_time = datetime.utcnow()
|
||||
expired_keys = []
|
||||
|
||||
for key, cached_data in self.attestation_cache.items():
|
||||
if datetime.fromisoformat(cached_data["timestamp"]) + self.attestation_ttl <= current_time:
|
||||
expired_keys.append(key)
|
||||
|
||||
for key in expired_keys:
|
||||
del self.attestation_cache[key]
|
||||
|
||||
logger.info(f"Cleaned up {len(expired_keys)} expired attestation cache entries")
|
||||
|
||||
|
||||
# Global TEE service instance
|
||||
tee_service = TEEService()
|
||||
Reference in New Issue
Block a user