mega changes

2025-11-20 11:11:18 +01:00
parent e070c95190
commit 841d79f26b
138 changed files with 21499 additions and 8844 deletions

@@ -15,9 +15,16 @@ import aiohttp
from .base import BaseLLMProvider
from ..models import (
-ChatRequest, ChatResponse, ChatMessage, ChatChoice, TokenUsage,
-EmbeddingRequest, EmbeddingResponse, EmbeddingData,
-ModelInfo, ProviderStatus
+ChatRequest,
+ChatResponse,
+ChatMessage,
+ChatChoice,
+TokenUsage,
+EmbeddingRequest,
+EmbeddingResponse,
+EmbeddingData,
+ModelInfo,
+ProviderStatus,
)
from ..config import ProviderConfig
from ..exceptions import ProviderError, ValidationError, TimeoutError
@@ -27,22 +34,22 @@ logger = logging.getLogger(__name__)
class PrivateModeProvider(BaseLLMProvider):
"""PrivateMode.ai provider with TEE security"""
def __init__(self, config: ProviderConfig, api_key: str):
super().__init__(config, api_key)
-self.base_url = config.base_url.rstrip('/')
+self.base_url = config.base_url.rstrip("/")
self._session: Optional[aiohttp.ClientSession] = None
# TEE-specific settings
self.verify_ssl = True # Always verify SSL for security
self.trust_env = False # Don't trust environment proxy settings
logger.info(f"PrivateMode provider initialized with base URL: {self.base_url}")
@property
def provider_name(self) -> str:
return "privatemode"
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create HTTP session with security settings"""
if self._session is None or self._session.closed:
@@ -52,45 +59,49 @@ class PrivateModeProvider(BaseLLMProvider):
limit=100, # Connection pool limit
limit_per_host=50,
ttl_dns_cache=300, # DNS cache TTL
-use_dns_cache=True
+use_dns_cache=True,
)
# Create session with security headers
-timeout = aiohttp.ClientTimeout(total=self.config.resilience.timeout_ms / 1000.0)
+timeout = aiohttp.ClientTimeout(
+total=self.config.resilience.timeout_ms / 1000.0
+)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers=self._create_headers(),
-trust_env=False # Don't trust environment variables
+trust_env=False, # Don't trust environment variables
)
logger.debug("Created new secure HTTP session for PrivateMode")
return self._session
async def health_check(self) -> ProviderStatus:
"""Check PrivateMode.ai service health"""
start_time = time.time()
try:
session = await self._get_session()
# Use a lightweight endpoint for health check
async with session.get(f"{self.base_url}/models") as response:
latency = (time.time() - start_time) * 1000
if response.status == 200:
models_data = await response.json()
-models = [model.get("id", "") for model in models_data.get("data", [])]
+models = [
+model.get("id", "") for model in models_data.get("data", [])
+]
return ProviderStatus(
provider=self.provider_name,
status="healthy",
latency_ms=latency,
success_rate=1.0,
last_check=datetime.utcnow(),
-models_available=models
+models_available=models,
)
else:
error_text = await response.text()
@@ -101,13 +112,13 @@ class PrivateModeProvider(BaseLLMProvider):
success_rate=0.0,
last_check=datetime.utcnow(),
error_message=f"HTTP {response.status}: {error_text}",
-models_available=[]
+models_available=[],
)
except Exception as e:
latency = (time.time() - start_time) * 1000
logger.error(f"PrivateMode health check failed: {e}")
return ProviderStatus(
provider=self.provider_name,
status="unavailable",
@@ -115,33 +126,33 @@ class PrivateModeProvider(BaseLLMProvider):
success_rate=0.0,
last_check=datetime.utcnow(),
error_message=str(e),
-models_available=[]
+models_available=[],
)
async def get_models(self) -> List[ModelInfo]:
"""Get available models from PrivateMode.ai"""
try:
session = await self._get_session()
async with session.get(f"{self.base_url}/models") as response:
if response.status == 200:
data = await response.json()
models_data = data.get("data", [])
models = []
for model_data in models_data:
model_id = model_data.get("id", "")
if not model_id:
continue
# Extract all information directly from API response
# Determine capabilities based on tasks field
tasks = model_data.get("tasks", [])
capabilities = []
# All PrivateMode models have TEE capability
capabilities.append("tee")
# Add capabilities based on tasks
if "generate" in tasks:
capabilities.append("chat")
@@ -149,12 +160,14 @@ class PrivateModeProvider(BaseLLMProvider):
capabilities.append("embeddings")
if "vision" in tasks:
capabilities.append("vision")
# Check for function calling support in the API response
supports_function_calling = model_data.get("supports_function_calling", False)
supports_function_calling = model_data.get(
"supports_function_calling", False
)
if supports_function_calling:
capabilities.append("function_calling")
model_info = ModelInfo(
id=model_id,
object="model",
@@ -164,40 +177,44 @@ class PrivateModeProvider(BaseLLMProvider):
capabilities=capabilities,
context_window=model_data.get("context_window"),
max_output_tokens=model_data.get("max_output_tokens"),
supports_streaming=model_data.get("supports_streaming", True),
supports_streaming=model_data.get(
"supports_streaming", True
),
supports_function_calling=supports_function_calling,
-tasks=tasks # Pass through tasks field from PrivateMode API
+tasks=tasks, # Pass through tasks field from PrivateMode API
)
models.append(model_info)
logger.info(f"Retrieved {len(models)} models from PrivateMode")
return models
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "models endpoint")
self._handle_http_error(
response.status, error_text, "models endpoint"
)
return [] # Never reached due to exception
except Exception as e:
if isinstance(e, ProviderError):
raise
logger.error(f"Failed to get models from PrivateMode: {e}")
raise ProviderError(
"Failed to retrieve models from PrivateMode",
provider=self.provider_name,
error_code="MODEL_RETRIEVAL_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
"""Create chat completion via PrivateMode.ai"""
self._validate_request(request)
start_time = time.time()
try:
session = await self._get_session()
# Prepare request payload
payload = {
"model": request.model,
@@ -205,14 +222,14 @@ class PrivateModeProvider(BaseLLMProvider):
{
"role": msg.role,
"content": msg.content,
**({"name": msg.name} if msg.name else {})
**({"name": msg.name} if msg.name else {}),
}
for msg in request.messages
],
"temperature": request.temperature,
"stream": False # Non-streaming version
"stream": False, # Non-streaming version
}
# Add optional parameters
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
@@ -224,28 +241,27 @@ class PrivateModeProvider(BaseLLMProvider):
payload["presence_penalty"] = request.presence_penalty
if request.stop is not None:
payload["stop"] = request.stop
# Add user tracking
payload["user"] = f"user_{request.user_id}"
# Add metadata for TEE audit trail
payload["metadata"] = {
"user_id": request.user_id,
"api_key_id": request.api_key_id,
"timestamp": datetime.utcnow().isoformat(),
"enclava_request_id": str(uuid.uuid4()),
-**(request.metadata or {})
+**(request.metadata or {}),
}
async with session.post(
f"{self.base_url}/chat/completions",
json=payload
f"{self.base_url}/chat/completions", json=payload
) as response:
provider_latency = (time.time() - start_time) * 1000
if response.status == 200:
data = await response.json()
# Parse response
choices = []
for choice_data in data.get("choices", []):
@@ -254,20 +270,20 @@ class PrivateModeProvider(BaseLLMProvider):
index=choice_data.get("index", 0),
message=ChatMessage(
role=message_data.get("role", "assistant"),
content=message_data.get("content", "")
content=message_data.get("content", ""),
),
finish_reason=choice_data.get("finish_reason")
finish_reason=choice_data.get("finish_reason"),
)
choices.append(choice)
# Parse token usage
usage_data = data.get("usage", {})
usage = TokenUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0)
total_tokens=usage_data.get("total_tokens", 0),
)
# Create response
chat_response = ChatResponse(
id=data.get("id", str(uuid.uuid4())),
@@ -279,45 +295,51 @@ class PrivateModeProvider(BaseLLMProvider):
usage=usage,
system_fingerprint=data.get("system_fingerprint"),
security_check=True, # Will be set by security manager
-risk_score=0.0, # Will be set by security manager
+risk_score=0.0,  # Will be set by security manager
latency_ms=provider_latency,
-provider_latency_ms=provider_latency
+provider_latency_ms=provider_latency,
)
-logger.debug(f"PrivateMode chat completion successful in {provider_latency:.2f}ms")
+logger.debug(
+f"PrivateMode chat completion successful in {provider_latency:.2f}ms"
+)
return chat_response
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "chat completion")
self._handle_http_error(
response.status, error_text, "chat completion"
)
except aiohttp.ClientError as e:
logger.error(f"PrivateMode request error: {e}")
raise ProviderError(
"Network error communicating with PrivateMode",
provider=self.provider_name,
error_code="NETWORK_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
except Exception as e:
if isinstance(e, (ProviderError, ValidationError)):
raise
logger.error(f"Unexpected error in PrivateMode chat completion: {e}")
raise ProviderError(
"Unexpected error during chat completion",
provider=self.provider_name,
error_code="UNEXPECTED_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
-async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
+async def create_chat_completion_stream(
+self, request: ChatRequest
+) -> AsyncGenerator[Dict[str, Any], None]:
"""Create streaming chat completion"""
self._validate_request(request)
try:
session = await self._get_session()
# Prepare streaming payload
payload = {
"model": request.model,
@@ -325,14 +347,14 @@ class PrivateModeProvider(BaseLLMProvider):
{
"role": msg.role,
"content": msg.content,
**({"name": msg.name} if msg.name else {})
**({"name": msg.name} if msg.name else {}),
}
for msg in request.messages
],
"temperature": request.temperature,
"stream": True
"stream": True,
}
# Add optional parameters
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
@@ -344,100 +366,104 @@ class PrivateModeProvider(BaseLLMProvider):
payload["presence_penalty"] = request.presence_penalty
if request.stop is not None:
payload["stop"] = request.stop
# Add user tracking
payload["user"] = f"user_{request.user_id}"
async with session.post(
f"{self.base_url}/chat/completions",
json=payload
f"{self.base_url}/chat/completions", json=payload
) as response:
if response.status == 200:
async for line in response.content:
-line = line.decode('utf-8').strip()
+line = line.decode("utf-8").strip()
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
break
try:
chunk_data = json.loads(data_str)
yield chunk_data
except json.JSONDecodeError:
logger.warning(f"Failed to parse streaming chunk: {data_str}")
logger.warning(
f"Failed to parse streaming chunk: {data_str}"
)
continue
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "streaming chat completion")
self._handle_http_error(
response.status, error_text, "streaming chat completion"
)
except aiohttp.ClientError as e:
logger.error(f"PrivateMode streaming error: {e}")
raise ProviderError(
"Network error during streaming",
provider=self.provider_name,
error_code="STREAMING_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
"""Create embeddings via PrivateMode.ai"""
self._validate_request(request)
start_time = time.time()
try:
session = await self._get_session()
# Prepare embedding payload
payload = {
"model": request.model,
"input": request.input,
"user": f"user_{request.user_id}"
"user": f"user_{request.user_id}",
}
# Add optional parameters
if request.encoding_format:
payload["encoding_format"] = request.encoding_format
if request.dimensions:
payload["dimensions"] = request.dimensions
# Add metadata
payload["metadata"] = {
"user_id": request.user_id,
"api_key_id": request.api_key_id,
"timestamp": datetime.utcnow().isoformat(),
-**(request.metadata or {})
+**(request.metadata or {}),
}
async with session.post(
f"{self.base_url}/embeddings",
json=payload
f"{self.base_url}/embeddings", json=payload
) as response:
provider_latency = (time.time() - start_time) * 1000
if response.status == 200:
data = await response.json()
# Parse embedding data
embeddings = []
for emb_data in data.get("data", []):
embedding = EmbeddingData(
object="embedding",
index=emb_data.get("index", 0),
embedding=emb_data.get("embedding", [])
embedding=emb_data.get("embedding", []),
)
embeddings.append(embedding)
# Parse usage
usage_data = data.get("usage", {})
usage = TokenUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=0, # No completion tokens for embeddings
total_tokens=usage_data.get("total_tokens", usage_data.get("prompt_tokens", 0))
total_tokens=usage_data.get(
"total_tokens", usage_data.get("prompt_tokens", 0)
),
)
return EmbeddingResponse(
object="list",
data=embeddings,
@@ -445,37 +471,39 @@ class PrivateModeProvider(BaseLLMProvider):
provider=self.provider_name,
usage=usage,
security_check=True, # Will be set by security manager
-risk_score=0.0, # Will be set by security manager
+risk_score=0.0,  # Will be set by security manager
latency_ms=provider_latency,
-provider_latency_ms=provider_latency
+provider_latency_ms=provider_latency,
)
else:
error_text = await response.text()
# Log the detailed error response from the provider
logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
logger.error(
f"PrivateMode embedding error - Status {response.status}: {error_text}"
)
self._handle_http_error(response.status, error_text, "embeddings")
except aiohttp.ClientError as e:
logger.error(f"PrivateMode embedding error: {e}")
raise ProviderError(
"Network error during embedding generation",
provider=self.provider_name,
error_code="EMBEDDING_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
except Exception as e:
if isinstance(e, (ProviderError, ValidationError)):
raise
logger.error(f"Unexpected error in PrivateMode embedding: {e}")
raise ProviderError(
"Unexpected error during embedding generation",
provider=self.provider_name,
error_code="UNEXPECTED_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def cleanup(self):
"""Cleanup PrivateMode provider resources"""
# Close HTTP session to prevent memory leaks
@@ -485,4 +513,4 @@ class PrivateModeProvider(BaseLLMProvider):
logger.debug("Closed PrivateMode HTTP session")
await super().cleanup()
logger.debug("PrivateMode provider cleanup completed")
logger.debug("PrivateMode provider cleanup completed")